Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ public void getWhenUsingDefaultsInLambdaWithValidBearerTokenThenAcceptsRequest()
public void getWhenUsingJwkSetUriThenAcceptsRequest() throws Exception {
this.spring.register(WebServerConfig.class, JwkSetUriConfig.class, BasicController.class).autowire();
mockWebServer(jwks("Default"));
mockWebServer(jwks("Default"));
String token = this.token("ValidNoScopes");

this.mvc.perform(get("/").with(bearerToken(token)))
Expand All @@ -228,6 +229,7 @@ public void getWhenUsingJwkSetUriThenAcceptsRequest() throws Exception {
public void getWhenUsingJwkSetUriInLambdaThenAcceptsRequest() throws Exception {
this.spring.register(WebServerConfig.class, JwkSetUriInLambdaConfig.class, BasicController.class).autowire();
mockWebServer(jwks("Default"));
mockWebServer(jwks("Default"));
String token = this.token("ValidNoScopes");

this.mvc.perform(get("/").with(bearerToken(token)))
Expand Down Expand Up @@ -1398,6 +1400,7 @@ public void getWhenMultipleIssuersThenUsesIssuerClaimToDifferentiate() throws Ex

mockWebServer(String.format(metadata, issuerOne, issuerOne));
mockWebServer(jwkSet);
mockWebServer(jwkSet);

this.mvc.perform(get("/authenticated")
.with(bearerToken(jwtOne)))
Expand All @@ -1406,6 +1409,7 @@ public void getWhenMultipleIssuersThenUsesIssuerClaimToDifferentiate() throws Ex

mockWebServer(String.format(metadata, issuerTwo, issuerTwo));
mockWebServer(jwkSet);
mockWebServer(jwkSet);

this.mvc.perform(get("/authenticated")
.with(bearerToken(jwtTwo)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ public void getWhenValidBearerTokenThenAcceptsRequest() throws Exception {
public void getWhenUsingJwkSetUriThenAcceptsRequest() throws Exception {
this.spring.configLocations(xml("WebServer"), xml("JwkSetUri")).autowire();
mockWebServer(jwks("Default"));
mockWebServer(jwks("Default"));
String token = this.token("ValidNoScopes");

this.mvc.perform(get("/")
Expand Down Expand Up @@ -834,20 +835,23 @@ public void getWhenMultipleIssuersThenUsesIssuerClaimToDifferentiate() throws Ex

mockWebServer(String.format(metadata, issuerOne, issuerOne));
mockWebServer(jwkSet);
mockWebServer(jwkSet);

this.mvc.perform(get("/authenticated")
.header("Authorization", "Bearer " + jwtOne))
.andExpect(status().isNotFound());

mockWebServer(String.format(metadata, issuerTwo, issuerTwo));
mockWebServer(jwkSet);
mockWebServer(jwkSet);

this.mvc.perform(get("/authenticated")
.header("Authorization", "Bearer " + jwtTwo))
.andExpect(status().isNotFound());

mockWebServer(String.format(metadata, issuerThree, issuerThree));
mockWebServer(jwkSet);
mockWebServer(jwkSet);

this.mvc.perform(get("/authenticated")
.header("Authorization", "Bearer " + jwtThree))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ public void getWhenUsingJwkSetUriThenConsultsAccordingly() {

MockWebServer mockWebServer = this.spring.getContext().getBean(MockWebServer.class);
mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet));
mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet));

this.client.get()
.headers(headers -> headers.setBearerAuth(this.messageReadTokenWithKid))
Expand All @@ -248,6 +249,7 @@ public void getWhenUsingJwkSetUriInLambdaThenConsultsAccordingly() {

MockWebServer mockWebServer = this.spring.getContext().getBean(MockWebServer.class);
mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet));
mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet));

this.client.get()
.headers(headers -> headers.setBearerAuth(this.messageReadTokenWithKid))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class ServerJwtDslTests {
fun `jwt when using custom JWK Set URI then custom URI used`() {
this.spring.register(CustomJwkSetUriConfig::class.java).autowire()

CustomJwkSetUriConfig.MOCK_WEB_SERVER.enqueue(MockResponse().setBody(jwkSet))
CustomJwkSetUriConfig.MOCK_WEB_SERVER.enqueue(MockResponse().setBody(jwkSet))

this.client.get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,13 @@

package org.springframework.security.oauth2.jwt;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import javax.crypto.SecretKey;

import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
Expand All @@ -49,7 +38,8 @@
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cache.Cache;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
Expand All @@ -65,6 +55,22 @@
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;

import javax.crypto.SecretKey;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

/**
* A low-level Nimbus implementation of {@link JwtDecoder} which takes a raw Nimbus configuration.
*
Expand Down Expand Up @@ -215,6 +221,9 @@ public static SecretKeyJwtDecoderBuilder withSecretKey(SecretKey secretKey) {
* <a target="_blank" href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a> uri.
*/
public static final class JwkSetUriJwtDecoderBuilder {

private static final Log log = LogFactory.getLog(JwkSetUriJwtDecoderBuilder.class);

private String jwkSetUri;
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
private RestOperations restOperations = new RestTemplate();
Expand Down Expand Up @@ -283,16 +292,58 @@ public JwkSetUriJwtDecoderBuilder cache(Cache cache) {
}

JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
Set<SignatureAlgorithm> algorithms = new HashSet<>();
if (!this.signatureAlgorithms.isEmpty()) {
algorithms.addAll(this.signatureAlgorithms);
} else {
Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
jwsAlgorithms.add(jwsAlgorithm);
algorithms.addAll(fetchSignatureAlgorithms());
}

if (algorithms.isEmpty()) {
algorithms.add(SignatureAlgorithm.RS256);
}

Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
for (SignatureAlgorithm signatureAlgorithm : algorithms) {
jwsAlgorithms.add(JWSAlgorithm.parse(signatureAlgorithm.getName()));
}

return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
}

private Set<SignatureAlgorithm> fetchSignatureAlgorithms() {
try {
return parseAlgorithms(JWKSet.load(toURL(jwkSetUri), 5000, 5000, 0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's please use the JWKSource to retrieve JWKs. That will allow you to use a JWKMatcher that removes some of the validation you are doing in parseAlgorithms.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was my initial approach, however for reasons I can't explain (hopefully somebody smarter than me can explain), importing the JWKMatcher.Builder() class causes a torrent of unit test failures on seemingly unrelated unit tests such as "ClassDefNotFound" for classes such as "LdapServerBeanDefinitionParserTests". If I can get that issue resolved i would be more than happy to replace this with the JWKSource.

Another issue with using JWKSource in the NimbusReactiveJwkDecoder is that (as far as i can tell) JWKSecurityContextJWKSet is passed in during invocation of the processor() method, calling the get() method on that class at this stage will always yield no results. I do not have a context to pass into it that contains JWKs.

} catch (Exception ex) {
throw new IllegalArgumentException("Failed to load Signature Algorithms from remote JWK source.", ex);
}
}

private Set<SignatureAlgorithm> parseAlgorithms(JWKSet jwkSet) {
if (jwkSet == null) {
throw new IllegalArgumentException(String.format("No JWKs received from %s", jwkSetUri));
}

List<JWK> jwks = new ArrayList<>();
for (JWK jwk : jwkSet.getKeys()) {
KeyUse keyUse = jwk.getKeyUse();
if (keyUse != null && keyUse.equals(KeyUse.SIGNATURE)) {
jwks.add(jwk);
}
}

Set<SignatureAlgorithm> algorithms = new HashSet<>();
for (JWK jwk : jwks) {
Algorithm algorithm = jwk.getAlgorithm();
if (algorithm != null) {
SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.from(algorithm.getName());
if (signatureAlgorithm != null) {
algorithms.add(signatureAlgorithm);
}
}
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
}

return algorithms;
}

JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,48 +15,35 @@
*/
package org.springframework.security.oauth2.jwt;

import java.security.interfaces.RSAPublicKey;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.crypto.SecretKey;

import com.nimbusds.jose.Header;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.*;
import com.nimbusds.jose.jwk.*;
import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWKSecurityContext;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jwt.PlainJWT;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jose.proc.*;
import com.nimbusds.jwt.*;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTProcessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import javax.crypto.SecretKey;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;

/**
* An implementation of a {@link ReactiveJwtDecoder} that &quot;decodes&quot; a
Expand Down Expand Up @@ -242,6 +229,9 @@ public static JwkSourceReactiveJwtDecoderBuilder withJwkSource(Function<SignedJW
* @since 5.2
*/
public static final class JwkSetUriReactiveJwtDecoderBuilder {

private static final Log log = LogFactory.getLog(JwkSetUriReactiveJwtDecoderBuilder.class);

private final String jwkSetUri;
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
private WebClient webClient = WebClient.create();
Expand Down Expand Up @@ -304,16 +294,61 @@ public NimbusReactiveJwtDecoder build() {
}

JWSKeySelector<JWKSecurityContext> jwsKeySelector(JWKSource<JWKSecurityContext> jwkSource) {
if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
Set<SignatureAlgorithm> algorithms = new HashSet<>();
if (!this.signatureAlgorithms.isEmpty()) {
algorithms.addAll(this.signatureAlgorithms);
} else {
Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
jwsAlgorithms.add(jwsAlgorithm);
algorithms.addAll(fetchSignatureAlgorithms());
}

if (algorithms.isEmpty()) {
algorithms.add(SignatureAlgorithm.RS256);
}

Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
for (SignatureAlgorithm signatureAlgorithm : algorithms) {
jwsAlgorithms.add(JWSAlgorithm.parse(signatureAlgorithm.getName()));
}

return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
}

private Set<SignatureAlgorithm> fetchSignatureAlgorithms() {
if (StringUtils.isEmpty(jwkSetUri)) {
return Collections.emptySet();
}
try {
return parseAlgorithms(JWKSet.load(toURL(jwkSetUri), 5000, 5000, 0));
} catch (Exception ex) {
throw new IllegalArgumentException("Failed to load Signature Algorithms from remote JWK source.", ex);
}
}

private Set<SignatureAlgorithm> parseAlgorithms(JWKSet jwkSet) {
if (jwkSet == null) {
throw new IllegalArgumentException(String.format("No JWKs received from %s", jwkSetUri));
}

List<JWK> jwks = new ArrayList<>();
for (JWK jwk : jwkSet.getKeys()) {
KeyUse keyUse = jwk.getKeyUse();
if (keyUse != null && keyUse.equals(KeyUse.SIGNATURE)) {
jwks.add(jwk);
}
}

Set<SignatureAlgorithm> algorithms = new HashSet<>();
for (JWK jwk : jwks) {
Algorithm algorithm = jwk.getAlgorithm();
if (algorithm != null) {
SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.from(algorithm.getName());
if (signatureAlgorithm != null) {
algorithms.add(signatureAlgorithm);
}
}
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
}

return algorithms;
}

Converter<JWT, Mono<JWTClaimsSet>> processor() {
Expand Down Expand Up @@ -350,6 +385,14 @@ private JWKSelector createSelector(Function<JWSAlgorithm, Boolean> expectedJwsAl

return new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
}

private static URL toURL(String url) {
try {
return new URL(url);
} catch (MalformedURLException ex) {
throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex);
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ private void prepareConfigurationResponse() {
private void prepareConfigurationResponse(String body) {
this.server.enqueue(response(body));
this.server.enqueue(response(JWK_SET));
this.server.enqueue(response(JWK_SET));
}

private void prepareConfigurationResponseOidc() {
Expand Down
Loading