Skip to content

Commit

Permalink
Support symmetric key for JwtDecoder
Browse files Browse the repository at this point in the history
Fixes gh-5465
  • Loading branch information
jgrandja committed Apr 12, 2019
1 parent fc6b66f commit bed3371
Show file tree
Hide file tree
Showing 15 changed files with 1,024 additions and 161 deletions.
Expand Up @@ -15,23 +15,30 @@
*/
package org.springframework.security.oauth2.client.oidc.authentication;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecretKey;

/**
* A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder}
Expand All @@ -47,14 +54,45 @@
*/
public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<ClientRegistration> {
private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier";
private static Map<JwsAlgorithm, String> jcaAlgorithmMappings = new HashMap<JwsAlgorithm, String>() {
{
put(MacAlgorithm.HS256, "HmacSHA256");
put(MacAlgorithm.HS384, "HmacSHA384");
put(MacAlgorithm.HS512, "HmacSHA512");
}
};
private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = OidcIdTokenValidator::new;
private Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver = clientRegistration -> SignatureAlgorithm.RS256;

@Override
public JwtDecoder createDecoder(ClientRegistration clientRegistration) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> {
if (!StringUtils.hasText(clientRegistration.getProviderDetails().getJwkSetUri())) {
NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
OAuth2TokenValidator<Jwt> jwtValidator = this.jwtValidatorFactory.apply(clientRegistration);
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
});
}

private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(clientRegistration);
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 6. If the ID Token is received via direct communication between the Client
// and the Token Endpoint (which it is in this flow),
// the TLS server validation MAY be used to validate the issuer in place of checking the token signature.
// The Client MUST validate the signature of all other ID Tokens according to JWS [JWS]
// using the algorithm specified in the JWT alg Header Parameter.
// The Client MUST use the keys provided by the Issuer.
//
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client
// in the id_token_signed_response_alg parameter during Registration.

String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
if (!StringUtils.hasText(jwkSetUri)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
Expand All @@ -64,12 +102,42 @@ public JwtDecoder createDecoder(ClientRegistration clientRegistration) {
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
NimbusJwtDecoder jwtDecoder = withJwkSetUri(jwkSetUri).build();
OAuth2TokenValidator<Jwt> jwtValidator = this.jwtValidatorFactory.apply(clientRegistration);
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
});
return withJwkSetUri(jwkSetUri).jwsAlgorithm(jwsAlgorithm).build();
} else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 8. If the JWT alg Header Parameter uses a MAC based algorithm such as HS256, HS384, or HS512,
// the octets of the UTF-8 representation of the client_secret
// corresponding to the client_id contained in the aud (audience) Claim
// are used as the key to validate the signature.
// For MAC based algorithms, the behavior is unspecified if the aud is multi-valued or
// if an azp value is present that is different than the aud value.

String clientSecret = clientRegistration.getClientSecret();
if (!StringUtils.hasText(clientSecret)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured the client secret.",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
SecretKeySpec secretKeySpec = new SecretKeySpec(
clientSecret.getBytes(StandardCharsets.UTF_8), jcaAlgorithmMappings.get(jwsAlgorithm));
return withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
}

OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured a valid JWS Algorithm: '" +
jwsAlgorithm + "'",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

/**
Expand All @@ -82,4 +150,17 @@ public final void setJwtValidatorFactory(Function<ClientRegistration, OAuth2Toke
Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null");
this.jwtValidatorFactory = jwtValidatorFactory;
}

/**
* Sets the resolver that provides the expected {@link JwsAlgorithm JWS algorithm}
* used for the signature or MAC on the {@link OidcIdToken ID Token}.
* The default resolves to {@link SignatureAlgorithm#RS256 RS256} for all {@link ClientRegistration clients}.
*
* @param jwsAlgorithmResolver the resolver that provides the expected {@link JwsAlgorithm JWS algorithm}
* for a specific {@link ClientRegistration client}
*/
public final void setJwsAlgorithmResolver(Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver) {
Assert.notNull(jwsAlgorithmResolver, "jwsAlgorithmResolver cannot be null");
this.jwsAlgorithmResolver = jwsAlgorithmResolver;
}
}
Expand Up @@ -20,17 +20,26 @@
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withSecretKey;

/**
* A {@link ReactiveJwtDecoderFactory factory} that provides a {@link ReactiveJwtDecoder}
* used for {@link OidcIdToken} signature verification.
Expand All @@ -45,14 +54,45 @@
*/
public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecoderFactory<ClientRegistration> {
private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier";
private static Map<JwsAlgorithm, String> jcaAlgorithmMappings = new HashMap<JwsAlgorithm, String>() {
{
put(MacAlgorithm.HS256, "HmacSHA256");
put(MacAlgorithm.HS384, "HmacSHA384");
put(MacAlgorithm.HS512, "HmacSHA512");
}
};
private final Map<String, ReactiveJwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = OidcIdTokenValidator::new;
private Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver = clientRegistration -> SignatureAlgorithm.RS256;

@Override
public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> {
if (!StringUtils.hasText(clientRegistration.getProviderDetails().getJwkSetUri())) {
NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
OAuth2TokenValidator<Jwt> jwtValidator = this.jwtValidatorFactory.apply(clientRegistration);
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
});
}

private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(clientRegistration);
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 6. If the ID Token is received via direct communication between the Client
// and the Token Endpoint (which it is in this flow),
// the TLS server validation MAY be used to validate the issuer in place of checking the token signature.
// The Client MUST validate the signature of all other ID Tokens according to JWS [JWS]
// using the algorithm specified in the JWT alg Header Parameter.
// The Client MUST use the keys provided by the Issuer.
//
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client
// in the id_token_signed_response_alg parameter during Registration.

String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
if (!StringUtils.hasText(jwkSetUri)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
Expand All @@ -62,12 +102,42 @@ public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
NimbusReactiveJwtDecoder jwtDecoder = new NimbusReactiveJwtDecoder(
clientRegistration.getProviderDetails().getJwkSetUri());
OAuth2TokenValidator<Jwt> jwtValidator = this.jwtValidatorFactory.apply(clientRegistration);
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
});
return withJwkSetUri(jwkSetUri).jwsAlgorithm(jwsAlgorithm).build();
} else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 8. If the JWT alg Header Parameter uses a MAC based algorithm such as HS256, HS384, or HS512,
// the octets of the UTF-8 representation of the client_secret
// corresponding to the client_id contained in the aud (audience) Claim
// are used as the key to validate the signature.
// For MAC based algorithms, the behavior is unspecified if the aud is multi-valued or
// if an azp value is present that is different than the aud value.

String clientSecret = clientRegistration.getClientSecret();
if (!StringUtils.hasText(clientSecret)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured the client secret.",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
SecretKeySpec secretKeySpec = new SecretKeySpec(
clientSecret.getBytes(StandardCharsets.UTF_8), jcaAlgorithmMappings.get(jwsAlgorithm));
return withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
}

OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured a valid JWS Algorithm: '" +
jwsAlgorithm + "'",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

/**
Expand All @@ -80,4 +150,17 @@ public final void setJwtValidatorFactory(Function<ClientRegistration, OAuth2Toke
Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null");
this.jwtValidatorFactory = jwtValidatorFactory;
}

/**
* Sets the resolver that provides the expected {@link JwsAlgorithm JWS algorithm}
* used for the signature or MAC on the {@link OidcIdToken ID Token}.
* The default resolves to {@link SignatureAlgorithm#RS256 RS256} for all {@link ClientRegistration clients}.
*
* @param jwsAlgorithmResolver the resolver that provides the expected {@link JwsAlgorithm JWS algorithm}
* for a specific {@link ClientRegistration client}
*/
public final void setJwsAlgorithmResolver(Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver) {
Assert.notNull(jwsAlgorithmResolver, "jwsAlgorithmResolver cannot be null");
this.jwsAlgorithmResolver = jwsAlgorithmResolver;
}
}

0 comments on commit bed3371

Please sign in to comment.