diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java index fe45f087f1e..87d486d8a74 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -147,7 +147,7 @@ public Saml2LogoutConfigurer(ApplicationContext context) { *

* The Relying Party triggers logout by POSTing to the endpoint. The Asserting Party * triggers logout based on what is specified by - * {@link RelyingPartyRegistration#getSingleLogoutServiceBinding()}. + * {@link RelyingPartyRegistration#getSingleLogoutServiceBindings()}. * @param logoutUrl the URL that will invoke logout * @return the {@link LogoutConfigurer} for further customizations * @see LogoutConfigurer#logoutUrl(String) @@ -343,7 +343,7 @@ public final class LogoutRequestConfigurer { * *

* The Asserting Party should use whatever HTTP method specified in - * {@link RelyingPartyRegistration#getSingleLogoutServiceBinding()}. + * {@link RelyingPartyRegistration#getSingleLogoutServiceBindings()}. * @param logoutUrl the URL that will receive the SAML 2.0 Logout Request * @return the {@link LogoutRequestConfigurer} for further customizations * @see Saml2LogoutConfigurer#logoutUrl(String) @@ -425,7 +425,7 @@ public final class LogoutResponseConfigurer { * *

* The Asserting Party should use whatever HTTP method specified in - * {@link RelyingPartyRegistration#getSingleLogoutServiceBinding()}. + * {@link RelyingPartyRegistration#getSingleLogoutServiceBindings()}. * @param logoutUrl the URL that will receive the SAML 2.0 Logout Response * @return the {@link LogoutResponseConfigurer} for further customizations * @see Saml2LogoutConfigurer#logoutUrl(String) diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java index 565b6547c73..4e0ad7f6a2b 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java @@ -46,6 +46,7 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService; import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.util.Assert; /** @@ -104,7 +105,9 @@ private SPSSODescriptor buildSpSsoDescriptor(RelyingPartyRegistration registrati .addAll(buildKeys(registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION)); spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration)); if (registration.getSingleLogoutServiceLocation() != null) { - spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration)); + for (Saml2MessageBinding binding : registration.getSingleLogoutServiceBindings()) { + spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration, binding)); + } } if (registration.getNameIdFormat() != null) { spSsoDescriptor.getNameIDFormats().add(buildNameIDFormat(registration)); @@ -147,11 +150,12 @@ private AssertionConsumerService buildAssertionConsumerService(RelyingPartyRegis return assertionConsumerService; } - private SingleLogoutService buildSingleLogoutService(RelyingPartyRegistration registration) { + private SingleLogoutService buildSingleLogoutService(RelyingPartyRegistration registration, + Saml2MessageBinding binding) { SingleLogoutService singleLogoutService = build(SingleLogoutService.DEFAULT_ELEMENT_NAME); singleLogoutService.setLocation(registration.getSingleLogoutServiceLocation()); singleLogoutService.setResponseLocation(registration.getSingleLogoutServiceResponseLocation()); - singleLogoutService.setBinding(registration.getSingleLogoutServiceBinding().getUrn()); + singleLogoutService.setBinding(binding.getUrn()); return singleLogoutService; } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java index f238c83a242..2064e32aa20 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java @@ -28,6 +28,7 @@ import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * Represents a configured relying party (aka Service Provider) and asserting party (aka @@ -81,7 +82,7 @@ public final class RelyingPartyRegistration { private final String singleLogoutServiceResponseLocation; - private final Saml2MessageBinding singleLogoutServiceBinding; + private final Collection singleLogoutServiceBindings; private final String nameIdFormat; @@ -93,7 +94,7 @@ public final class RelyingPartyRegistration { private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation, Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation, - String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding, + String singleLogoutServiceResponseLocation, Collection singleLogoutServiceBindings, AssertingPartyDetails assertingPartyDetails, String nameIdFormat, Collection decryptionX509Credentials, Collection signingX509Credentials) { @@ -101,8 +102,8 @@ private RelyingPartyRegistration(String registrationId, String entityId, String Assert.hasText(entityId, "entityId cannot be empty"); Assert.hasText(assertionConsumerServiceLocation, "assertionConsumerServiceLocation cannot be empty"); Assert.notNull(assertionConsumerServiceBinding, "assertionConsumerServiceBinding cannot be null"); - Assert.isTrue(singleLogoutServiceLocation == null || singleLogoutServiceBinding != null, - "singleLogoutServiceBinding cannot be null when singleLogoutServiceLocation is set"); + Assert.isTrue(singleLogoutServiceLocation == null || !CollectionUtils.isEmpty(singleLogoutServiceBindings), + "singleLogoutServiceBindings cannot be null or empty when singleLogoutServiceLocation is set"); Assert.notNull(assertingPartyDetails, "assertingPartyDetails cannot be null"); Assert.notNull(decryptionX509Credentials, "decryptionX509Credentials cannot be null"); for (Saml2X509Credential c : decryptionX509Credentials) { @@ -121,7 +122,7 @@ private RelyingPartyRegistration(String registrationId, String entityId, String this.assertionConsumerServiceBinding = assertionConsumerServiceBinding; this.singleLogoutServiceLocation = singleLogoutServiceLocation; this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation; - this.singleLogoutServiceBinding = singleLogoutServiceBinding; + this.singleLogoutServiceBindings = Collections.unmodifiableList(new LinkedList<>(singleLogoutServiceBindings)); this.nameIdFormat = nameIdFormat; this.assertingPartyDetails = assertingPartyDetails; this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials)); @@ -194,7 +195,22 @@ public Saml2MessageBinding getAssertionConsumerServiceBinding() { * @since 5.6 */ public Saml2MessageBinding getSingleLogoutServiceBinding() { - return this.singleLogoutServiceBinding; + Assert.state(this.singleLogoutServiceBindings.size() == 1, "Method does not support multiple bindings."); + return this.singleLogoutServiceBindings.iterator().next(); + } + + /** + * Get the SingleLogoutService + * Binding + *

+ * Equivalent to the value found in <SingleLogoutService Binding="..."/> in the + * relying party's <SPSSODescriptor>. + * @return the SingleLogoutService Binding + * @since 5.8 + */ + public Collection getSingleLogoutServiceBindings() { + return this.singleLogoutServiceBindings; } /** @@ -308,7 +324,7 @@ public static Builder withRelyingPartyRegistration(RelyingPartyRegistration regi .assertionConsumerServiceBinding(registration.getAssertionConsumerServiceBinding()) .singleLogoutServiceLocation(registration.getSingleLogoutServiceLocation()) .singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation()) - .singleLogoutServiceBinding(registration.getSingleLogoutServiceBinding()) + .singleLogoutServiceBindings((c) -> c.addAll(registration.getSingleLogoutServiceBindings())) .nameIdFormat(registration.getNameIdFormat()) .assertingPartyDetails((assertingParty) -> assertingParty .entityId(registration.getAssertingPartyDetails().getEntityId()) @@ -737,7 +753,7 @@ public static final class Builder { private String singleLogoutServiceResponseLocation; - private Saml2MessageBinding singleLogoutServiceBinding = Saml2MessageBinding.POST; + private Collection singleLogoutServiceBindings = new LinkedHashSet<>(); private String nameIdFormat = null; @@ -855,7 +871,28 @@ public Builder assertionConsumerServiceBinding(Saml2MessageBinding assertionCons * @since 5.6 */ public Builder singleLogoutServiceBinding(Saml2MessageBinding singleLogoutServiceBinding) { - this.singleLogoutServiceBinding = singleLogoutServiceBinding; + return this.singleLogoutServiceBindings((saml2MessageBindings) -> { + saml2MessageBindings.clear(); + saml2MessageBindings.add(singleLogoutServiceBinding); + }); + } + + /** + * Apply this {@link Consumer} to the {@link Collection} of + * {@link Saml2MessageBinding}s for the purposes of modifying the SingleLogoutService + * Binding {@link Collection}. + * + *

+ * Equivalent to the value found in <SingleLogoutService Binding="..."/> in + * the relying party's <SPSSODescriptor>. + * @param bindingsConsumer - the {@link Consumer} for modifying the + * {@link Collection} + * @return the {@link Builder} for further configuration + * @since 5.8 + */ + public Builder singleLogoutServiceBindings(Consumer> bindingsConsumer) { + bindingsConsumer.accept(this.singleLogoutServiceBindings); return this; } @@ -925,10 +962,15 @@ public RelyingPartyRegistration build() { if (this.singleLogoutServiceResponseLocation == null) { this.singleLogoutServiceResponseLocation = this.singleLogoutServiceLocation; } + + if (this.singleLogoutServiceBindings.isEmpty()) { + this.singleLogoutServiceBindings.add(Saml2MessageBinding.POST); + } + return new RelyingPartyRegistration(this.registrationId, this.entityId, this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding, this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation, - this.singleLogoutServiceBinding, this.assertingPartyDetailsBuilder.build(), this.nameIdFormat, + this.singleLogoutServiceBindings, this.assertingPartyDetailsBuilder.build(), this.nameIdFormat, this.decryptionX509Credentials, this.signingX509Credentials); } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolver.java index aad6eae0473..11cc276e29f 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutResponseResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -134,9 +134,7 @@ Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentic if (registration.getAssertingPartyDetails().getSingleLogoutServiceResponseLocation() == null) { return null; } - String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST); - byte[] b = Saml2Utils.samlDecode(serialized); - LogoutRequest logoutRequest = parse(inflateIfRequired(registration, b)); + LogoutRequest logoutRequest = parse(extractSamlRequest(request)); LogoutResponse logoutResponse = this.logoutResponseBuilder.buildObject(); logoutResponse.setDestination(registration.getAssertingPartyDetails().getSingleLogoutServiceResponseLocation()); Issuer issuer = this.issuerBuilder.buildObject(); @@ -189,8 +187,10 @@ private String getRegistrationId(Authentication authentication) { return null; } - private String inflateIfRequired(RelyingPartyRegistration registration, byte[] b) { - if (registration.getSingleLogoutServiceBinding() == Saml2MessageBinding.REDIRECT) { + private String extractSamlRequest(HttpServletRequest request) { + String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST); + byte[] b = Saml2Utils.samlDecode(serialized); + if (Saml2MessageBindingUtils.isHttpRedirectBinding(request)) { return Saml2Utils.samlInflate(b); } return new String(b, StandardCharsets.UTF_8); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java index 92b6e3da6b7..cd5f7d9363e 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java @@ -122,7 +122,9 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response.sendError(HttpServletResponse.SC_UNAUTHORIZED); return; } - if (!isCorrectBinding(request, registration)) { + + Saml2MessageBinding saml2MessageBinding = Saml2MessageBindingUtils.resolveBinding(request); + if (!registration.getSingleLogoutServiceBindings().contains(saml2MessageBinding)) { this.logger.trace("Did not process logout request since used incorrect binding"); response.sendError(HttpServletResponse.SC_UNAUTHORIZED); return; @@ -131,8 +133,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST); Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration) .samlRequest(serialized).relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE)) - .binding(registration.getSingleLogoutServiceBinding()) - .location(registration.getSingleLogoutServiceLocation()) + .binding(saml2MessageBinding).location(registration.getSingleLogoutServiceLocation()) .parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG, request.getParameter(Saml2ParameterNames.SIG_ALG))) .parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE, @@ -177,14 +178,6 @@ private String getRegistrationId(Authentication authentication) { return null; } - private boolean isCorrectBinding(HttpServletRequest request, RelyingPartyRegistration registration) { - Saml2MessageBinding requiredBinding = registration.getSingleLogoutServiceBinding(); - if (requiredBinding == Saml2MessageBinding.POST) { - return "POST".equals(request.getMethod()); - } - return "GET".equals(request.getMethod()); - } - private void doRedirect(HttpServletRequest request, HttpServletResponse response, Saml2LogoutResponse logoutResponse) throws IOException { String location = logoutResponse.getResponseLocation(); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilter.java index 0520a6e4b28..52bc2d4f156 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseFilter.java @@ -125,8 +125,10 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response.sendError(HttpServletResponse.SC_UNAUTHORIZED); return; } - if (!isCorrectBinding(request, registration)) { - this.logger.trace("Did not process logout request since used incorrect binding"); + + Saml2MessageBinding saml2MessageBinding = Saml2MessageBindingUtils.resolveBinding(request); + if (!registration.getSingleLogoutServiceBindings().contains(saml2MessageBinding)) { + this.logger.trace("Did not process logout response since used incorrect binding"); response.sendError(HttpServletResponse.SC_UNAUTHORIZED); return; } @@ -134,8 +136,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE); Saml2LogoutResponse logoutResponse = Saml2LogoutResponse.withRelyingPartyRegistration(registration) .samlResponse(serialized).relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE)) - .binding(registration.getSingleLogoutServiceBinding()) - .location(registration.getSingleLogoutServiceResponseLocation()) + .binding(saml2MessageBinding).location(registration.getSingleLogoutServiceResponseLocation()) .parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG, request.getParameter(Saml2ParameterNames.SIG_ALG))) .parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE, @@ -167,12 +168,4 @@ public void setLogoutRequestRepository(Saml2LogoutRequestRepository logoutReques this.logoutRequestRepository = logoutRequestRepository; } - private boolean isCorrectBinding(HttpServletRequest request, RelyingPartyRegistration registration) { - Saml2MessageBinding requiredBinding = registration.getSingleLogoutServiceBinding(); - if (requiredBinding == Saml2MessageBinding.POST) { - return "POST".equals(request.getMethod()); - } - return "GET".equals(request.getMethod()); - } - } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2MessageBindingUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2MessageBindingUtils.java new file mode 100644 index 00000000000..ebd19bca595 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2MessageBindingUtils.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web.authentication.logout; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2ParameterNames; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; + +/** + * Utility methods for working with {@link Saml2MessageBinding} + * + * For internal use only. + * + * @since 5.8 + */ +final class Saml2MessageBindingUtils { + + private Saml2MessageBindingUtils() { + } + + static Saml2MessageBinding resolveBinding(HttpServletRequest request) { + if (isHttpPostBinding(request)) { + return Saml2MessageBinding.POST; + } + else if (isHttpRedirectBinding(request)) { + return Saml2MessageBinding.REDIRECT; + } + throw new Saml2Exception("Unable to determine message binding from request."); + } + + private static boolean isSamlRequestResponse(HttpServletRequest request) { + return (request.getParameter(Saml2ParameterNames.SAML_REQUEST) != null + || request.getParameter(Saml2ParameterNames.SAML_RESPONSE) != null); + } + + static boolean isHttpRedirectBinding(HttpServletRequest request) { + return request != null && "GET".equalsIgnoreCase(request.getMethod()) && isSamlRequestResponse(request); + } + + static boolean isHttpPostBinding(HttpServletRequest request) { + return request != null && "POST".equalsIgnoreCase(request.getMethod()) && isSamlRequestResponse(request); + } + +}