Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ public Authentication authenticate(Authentication authentication) throws Authent

boolean authenticatedCredentials = false;

if (!registeredClient.getClientAuthenticationMethods().contains(
clientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient();
}

if (clientAuthentication.getCredentials() != null) {
String clientSecret = clientAuthentication.getCredentials().toString();
// TODO Use PasswordEncoder.matches()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.server.authorization.Version;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
Expand All @@ -30,6 +31,7 @@
*
* @author Joe Grandja
* @author Patryk Kostrzewa
* @author Anoop Garlapati
* @since 0.0.1
* @see AbstractAuthenticationToken
* @see RegisteredClient
Expand All @@ -39,6 +41,7 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
private String clientId;
private String clientSecret;
private ClientAuthenticationMethod clientAuthenticationMethod;
private Map<String, Object> additionalParameters;
private RegisteredClient registeredClient;

Expand All @@ -47,13 +50,17 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
*
* @param clientId the client identifier
* @param clientSecret the client secret
* @param clientAuthenticationMethod the authentication method used by the client
* @param additionalParameters the additional parameters
*/
public OAuth2ClientAuthenticationToken(String clientId, String clientSecret,
ClientAuthenticationMethod clientAuthenticationMethod,
@Nullable Map<String, Object> additionalParameters) {
this(clientId, additionalParameters);
Assert.hasText(clientSecret, "clientSecret cannot be empty");
Assert.notNull(clientAuthenticationMethod, "clientAuthenticationMethod cannot be null");
this.clientSecret = clientSecret;
this.clientAuthenticationMethod = clientAuthenticationMethod;
}

/**
Expand All @@ -69,6 +76,7 @@ public OAuth2ClientAuthenticationToken(String clientId,
this.clientId = clientId;
this.additionalParameters = additionalParameters != null ?
Collections.unmodifiableMap(additionalParameters) : null;
this.clientAuthenticationMethod = ClientAuthenticationMethod.NONE;
}

/**
Expand Down Expand Up @@ -112,4 +120,13 @@ public Object getCredentials() {
public @Nullable RegisteredClient getRegisteredClient() {
return this.registeredClient;
}

/**
* Returns the {@link ClientAuthenticationMethod client authentication method}.
*
* @return the {@link ClientAuthenticationMethod}
*/
public @Nullable ClientAuthenticationMethod getClientAuthenticationMethod() {
return this.clientAuthenticationMethod;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import org.springframework.http.HttpHeaders;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
Expand Down Expand Up @@ -85,7 +86,8 @@ public Authentication convert(HttpServletRequest request) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST), ex);
}

return new OAuth2ClientAuthenticationToken(clientID, clientSecret, extractAdditionalParameters(request));
return new OAuth2ClientAuthenticationToken(clientID, clientSecret, ClientAuthenticationMethod.BASIC,
extractAdditionalParameters(request));
}

private static Map<String, Object> extractAdditionalParameters(HttpServletRequest request) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.web;

import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
* Attempts to extract client credentials from POST parameters of {@link HttpServletRequest}
* and then converts to an {@link OAuth2ClientAuthenticationToken} used for authenticating the client.
*
* @author Anoop Garlapati
* @since 0.1.0
* @see OAuth2ClientAuthenticationToken
* @see OAuth2ClientAuthenticationFilter
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-2.3.1">Section 2.3.1 Client Password</a>
*/
public class ClientSecretPostAuthenticationConverter implements AuthenticationConverter {

@Override
public Authentication convert(HttpServletRequest request) {
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);

// client_id (REQUIRED)
String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
if (!StringUtils.hasText(clientId)) {
return null;
}

if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
}

// client_secret (REQUIRED)
String clientSecret = parameters.getFirst(OAuth2ParameterNames.CLIENT_SECRET);
if (!StringUtils.hasText(clientSecret)) {
return null;
}

if (parameters.get(OAuth2ParameterNames.CLIENT_SECRET).size() != 1) {
throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
}

return new OAuth2ClientAuthenticationToken(clientId, clientSecret, ClientAuthenticationMethod.POST,
extractAdditionalParameters(request));
}

private static Map<String, Object> extractAdditionalParameters(HttpServletRequest request) {
Map<String, Object> additionalParameters = Collections.emptyMap();
if (OAuth2EndpointUtils.matchesPkceTokenRequest(request)) {
// Confidential clients can also leverage PKCE
additionalParameters = new HashMap<>(OAuth2EndpointUtils.getParameters(request).toSingleValueMap());
additionalParameters.remove(OAuth2ParameterNames.CLIENT_ID);
additionalParameters.remove(OAuth2ParameterNames.CLIENT_SECRET);
}
return additionalParameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationMana
this.authenticationConverter = new DelegatingAuthenticationConverter(
Arrays.asList(
new ClientSecretBasicAuthenticationConverter(),
new ClientSecretPostAuthenticationConverter(),
new PublicClientAuthenticationConverter()));
this.authenticationSuccessHandler = this::onAuthenticationSuccess;
this.authenticationFailureHandler = this::onAuthenticationFailure;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@ public void requestWhenTokenRequestValidThenReturnAccessTokenResponse() throws E
public void requestWhenPublicClientWithPkceThenReturnAccessTokenResponse() throws Exception {
this.spring.register(AuthorizationServerConfiguration.class).autowire();

RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
.clientSecret(null)
.clientSettings(clientSettings -> clientSettings.requireProofKey(true))
RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient()
.tokenSettings(tokenSettings -> tokenSettings.enableRefreshTokens(false))
.build();
when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
Expand Down Expand Up @@ -119,7 +120,7 @@ public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThe
public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
registeredClient.getClientId(), registeredClient.getClientSecret(), null);
registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null);
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, null, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.junit.Before;
import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
Expand Down Expand Up @@ -45,6 +46,7 @@
* @author Patryk Kostrzewa
* @author Joe Grandja
* @author Daniel Garnier-Moiroux
* @author Anoop Garlapati
*/
public class OAuth2ClientAuthenticationProviderTests {
private static final String PLAIN_CODE_VERIFIER = "pkce-key";
Expand Down Expand Up @@ -95,7 +97,7 @@ public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthenticationExceptio
.thenReturn(registeredClient);

OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
registeredClient.getClientId() + "-invalid", registeredClient.getClientSecret(), null);
registeredClient.getClientId() + "-invalid", registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
Expand All @@ -110,7 +112,7 @@ public void authenticateWhenInvalidClientSecretThenThrowOAuth2AuthenticationExce
.thenReturn(registeredClient);

OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
registeredClient.getClientId(), registeredClient.getClientSecret() + "-invalid", null);
registeredClient.getClientId(), registeredClient.getClientSecret() + "-invalid", ClientAuthenticationMethod.BASIC, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
Expand Down Expand Up @@ -140,7 +142,7 @@ public void authenticateWhenValidCredentialsThenAuthenticated() {
.thenReturn(registeredClient);

OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
registeredClient.getClientId(), registeredClient.getClientSecret(), null);
registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null);
OAuth2ClientAuthenticationToken authenticationResult =
(OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication);
assertThat(authenticationResult.isAuthenticated()).isTrue();
Expand Down Expand Up @@ -275,7 +277,7 @@ public void authenticateWhenPkceAndS256MethodAndInvalidCodeVerifierThenThrowOAut

@Test
public void authenticateWhenPkceAndPlainMethodAndValidCodeVerifierThenAuthenticated() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);

Expand All @@ -300,7 +302,7 @@ public void authenticateWhenPkceAndPlainMethodAndValidCodeVerifierThenAuthentica

@Test
public void authenticateWhenPkceAndMissingMethodThenDefaultPlainMethodAndAuthenticated() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);

Expand All @@ -327,7 +329,7 @@ public void authenticateWhenPkceAndMissingMethodThenDefaultPlainMethodAndAuthent

@Test
public void authenticateWhenPkceAndS256MethodAndValidCodeVerifierThenAuthenticated() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);

Expand All @@ -352,7 +354,7 @@ public void authenticateWhenPkceAndS256MethodAndValidCodeVerifierThenAuthenticat

@Test
public void authenticateWhenPkceAndUnsupportedCodeChallengeMethodThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);

Expand All @@ -377,6 +379,21 @@ public void authenticateWhenPkceAndUnsupportedCodeChallengeMethodThenThrowOAuth2
.isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
}

@Test
public void authenticateWhenClientAuthenticationWithUnregisteredClientAuthenticationMethodThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
.thenReturn(registeredClient);

OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.POST, null);
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.extracting("errorCode")
.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
}

private static Map<String, Object> createPkceTokenParameters(String codeVerifier) {
Map<String, Object> parameters = new HashMap<>();
parameters.put(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.security.oauth2.server.authorization.authentication;

import org.junit.Test;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;

Expand All @@ -29,23 +30,31 @@
* Tests for {@link OAuth2ClientAuthenticationToken}.
*
* @author Joe Grandja
* @author Anoop Garlapati
*/
public class OAuth2ClientAuthenticationTokenTests {

@Test
public void constructorWhenClientIdNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken(null, "secret", null))
assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken(null, "secret", ClientAuthenticationMethod.BASIC, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientId cannot be empty");
}

@Test
public void constructorWhenClientSecretNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken("clientId", null, null))
assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken("clientId", null, ClientAuthenticationMethod.BASIC, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientSecret cannot be empty");
}

@Test
public void constructorWhenClientAuthenticationMethodNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken("clientId", "clientSecret", null, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("clientAuthenticationMethod cannot be null");
}

@Test
public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken(null))
Expand All @@ -55,11 +64,13 @@ public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException

@Test
public void constructorWhenClientCredentialsProvidedThenCreated() {
OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken("clientId", "secret", null);
OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken("clientId", "secret",
ClientAuthenticationMethod.BASIC, null);
assertThat(authentication.isAuthenticated()).isFalse();
assertThat(authentication.getPrincipal().toString()).isEqualTo("clientId");
assertThat(authentication.getCredentials()).isEqualTo("secret");
assertThat(authentication.getRegisteredClient()).isNull();
assertThat(authentication.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC);
}

@Test
Expand All @@ -72,6 +83,7 @@ public void constructorWhenClientIdProvidedThenCreated() {
assertThat(authentication.getCredentials()).isNull();
assertThat(authentication.getAdditionalParameters()).isEqualTo(additionalParameters);
assertThat(authentication.getRegisteredClient()).isNull();
assertThat(authentication.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.NONE);
}

@Test
Expand All @@ -83,4 +95,15 @@ public void constructorWhenRegisteredClientProvidedThenCreated() {
assertThat(authentication.getCredentials()).isNull();
assertThat(authentication.getRegisteredClient()).isEqualTo(registeredClient);
}

@Test
public void constructorWhenClientCredentialsAndClientAuthenticationMethodProvidedThenCreated() {
OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken("clientId", "secret",
ClientAuthenticationMethod.BASIC, null);
assertThat(authentication.isAuthenticated()).isFalse();
assertThat(authentication.getPrincipal().toString()).isEqualTo("clientId");
assertThat(authentication.getCredentials()).isEqualTo("secret");
assertThat(authentication.getRegisteredClient()).isNull();
assertThat(authentication.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC);
}
}
Loading