diff --git a/tokens/src/main/java/io/scalecube/security/tokens/jwt/JwksKeyProvider.java b/tokens/src/main/java/io/scalecube/security/tokens/jwt/JwksKeyProvider.java index a3edcc6..cb4074e 100644 --- a/tokens/src/main/java/io/scalecube/security/tokens/jwt/JwksKeyProvider.java +++ b/tokens/src/main/java/io/scalecube/security/tokens/jwt/JwksKeyProvider.java @@ -1,5 +1,7 @@ package io.scalecube.security.tokens.jwt; +import static io.scalecube.security.tokens.jwt.Utils.toRsaPublicKey; + import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.PropertyAccessor; @@ -24,9 +26,9 @@ public final class JwksKeyProvider implements KeyProvider { private static final Logger LOGGER = LoggerFactory.getLogger(JwksKeyProvider.class); - private final Scheduler scheduler = Schedulers.newSingle("jwks-key-provider", true); + private static final ObjectMapper OBJECT_MAPPER = newObjectMapper(); - private final ObjectMapper mapper; + private final Scheduler scheduler; private final String jwksUri; private final long connectTimeoutMillis; private final long readTimeoutMillis; @@ -37,22 +39,21 @@ public final class JwksKeyProvider implements KeyProvider { * @param jwksUri jwksUri */ public JwksKeyProvider(String jwksUri) { - this.jwksUri = jwksUri; - this.mapper = initMapper(); - this.connectTimeoutMillis = Duration.ofSeconds(10).toMillis(); - this.readTimeoutMillis = Duration.ofSeconds(10).toMillis(); + this(jwksUri, newScheduler(), Duration.ofSeconds(10), Duration.ofSeconds(10)); } /** * Constructor. * * @param jwksUri jwksUri + * @param scheduler scheduler * @param connectTimeout connectTimeout * @param readTimeout readTimeout */ - public JwksKeyProvider(String jwksUri, Duration connectTimeout, Duration readTimeout) { + public JwksKeyProvider( + String jwksUri, Scheduler scheduler, Duration connectTimeout, Duration readTimeout) { this.jwksUri = jwksUri; - this.mapper = initMapper(); + this.scheduler = scheduler; this.connectTimeoutMillis = connectTimeout.toMillis(); this.readTimeoutMillis = readTimeout.toMillis(); } @@ -87,7 +88,7 @@ private Mono callJwksUri() { private JwkInfoList toKeyList(InputStream stream) { try (InputStream inputStream = new BufferedInputStream(stream)) { - return mapper.readValue(inputStream, JwkInfoList.class); + return OBJECT_MAPPER.readValue(inputStream, JwkInfoList.class); } catch (IOException e) { LOGGER.error("[toKeyList] Exception occurred: {}", e.toString()); throw new KeyProviderException(e); @@ -98,10 +99,10 @@ private Optional findRsaKey(JwkInfoList list, String kid) { return list.keys().stream() .filter(k -> kid.equals(k.kid())) .findFirst() - .map(info -> Utils.getRsaPublicKey(info.modulus(), info.exponent())); + .map(info -> toRsaPublicKey(info.modulus(), info.exponent())); } - private static ObjectMapper initMapper() { + private static ObjectMapper newObjectMapper() { ObjectMapper mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); @@ -111,4 +112,8 @@ private static ObjectMapper initMapper() { mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); return mapper; } + + private static Scheduler newScheduler() { + return Schedulers.newElastic("jwks-key-provider", 60, true); + } } diff --git a/tokens/src/main/java/io/scalecube/security/tokens/jwt/JwtTokenResolverImpl.java b/tokens/src/main/java/io/scalecube/security/tokens/jwt/JwtTokenResolverImpl.java index c344fb0..a0ca216 100644 --- a/tokens/src/main/java/io/scalecube/security/tokens/jwt/JwtTokenResolverImpl.java +++ b/tokens/src/main/java/io/scalecube/security/tokens/jwt/JwtTokenResolverImpl.java @@ -1,6 +1,8 @@ package io.scalecube.security.tokens.jwt; +import io.scalecube.security.tokens.jwt.jsonwebtoken.JsonwebtokenParserFactory; import java.security.Key; +import java.time.Duration; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; @@ -18,8 +20,8 @@ public final class JwtTokenResolverImpl implements JwtTokenResolver { private final KeyProvider keyProvider; private final JwtTokenParserFactory tokenParserFactory; - private final int cleanupIntervalSec; private final Scheduler scheduler; + private final Duration cleanupInterval; private final Map> keyResolutions = new ConcurrentHashMap<>(); @@ -27,10 +29,9 @@ public final class JwtTokenResolverImpl implements JwtTokenResolver { * Constructor. * * @param keyProvider key provider - * @param tokenParserFactory token parser factoty */ - public JwtTokenResolverImpl(KeyProvider keyProvider, JwtTokenParserFactory tokenParserFactory) { - this(keyProvider, tokenParserFactory, 3600, Schedulers.newSingle("caching-key-provider", true)); + public JwtTokenResolverImpl(KeyProvider keyProvider) { + this(keyProvider, new JsonwebtokenParserFactory(), newScheduler(), Duration.ofSeconds(60)); } /** @@ -38,17 +39,17 @@ public JwtTokenResolverImpl(KeyProvider keyProvider, JwtTokenParserFactory token * * @param keyProvider key provider * @param tokenParserFactory token parser factoty - * @param cleanupIntervalSec cleanup interval (in sec) for resolved cached keys * @param scheduler cleanup scheduler + * @param cleanupInterval cleanup interval for resolved cached keys */ public JwtTokenResolverImpl( KeyProvider keyProvider, JwtTokenParserFactory tokenParserFactory, - int cleanupIntervalSec, - Scheduler scheduler) { + Scheduler scheduler, + Duration cleanupInterval) { this.keyProvider = keyProvider; this.tokenParserFactory = tokenParserFactory; - this.cleanupIntervalSec = cleanupIntervalSec; + this.cleanupInterval = cleanupInterval; this.scheduler = scheduler; } @@ -107,7 +108,7 @@ private Mono findKey(String kid, AtomicReference> computedValueHo private void scheduleCleanup(String kid, AtomicReference> computedValueHolder) { scheduler.schedule( - () -> cleanup(kid, computedValueHolder), cleanupIntervalSec, TimeUnit.SECONDS); + () -> cleanup(kid, computedValueHolder), cleanupInterval.toMillis(), TimeUnit.MILLISECONDS); } private void cleanup(String kid, AtomicReference> computedValueHolder) { @@ -115,4 +116,8 @@ private void cleanup(String kid, AtomicReference> computedValueHolder) keyResolutions.remove(kid, computedValueHolder.get()); } } + + private static Scheduler newScheduler() { + return Schedulers.newElastic("token-resolver-cleaner", 60, true); + } } diff --git a/tokens/src/main/java/io/scalecube/security/tokens/jwt/Utils.java b/tokens/src/main/java/io/scalecube/security/tokens/jwt/Utils.java index c33a2da..0f0ca3d 100644 --- a/tokens/src/main/java/io/scalecube/security/tokens/jwt/Utils.java +++ b/tokens/src/main/java/io/scalecube/security/tokens/jwt/Utils.java @@ -22,7 +22,7 @@ private Utils() { * @param e exponent (b64 url encoded) * @return RSA public key instance */ - public static Key getRsaPublicKey(String n, String e) { + public static Key toRsaPublicKey(String n, String e) { Decoder b64Decoder = Base64.getUrlDecoder(); BigInteger modulus = new BigInteger(1, b64Decoder.decode(n)); BigInteger exponent = new BigInteger(1, b64Decoder.decode(e)); diff --git a/tokens/src/test/java/io/scalecube/security/tokens/jwt/BaseTest.java b/tokens/src/test/java/io/scalecube/security/tokens/jwt/BaseTest.java deleted file mode 100644 index 7058ece..0000000 --- a/tokens/src/test/java/io/scalecube/security/tokens/jwt/BaseTest.java +++ /dev/null @@ -1,49 +0,0 @@ -package io.scalecube.security.tokens.jwt; - -import java.lang.reflect.Method; -import java.time.Duration; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.TestInfo; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.test.StepVerifier; - -public class BaseTest { - - protected static final Logger LOGGER = LoggerFactory.getLogger(BaseTest.class); - - public static final Duration TIMEOUT = Duration.ofSeconds(10); - - @BeforeAll - public static void init() { - StepVerifier.setDefaultTimeout(TIMEOUT); - } - - @AfterAll - public static void reset() { - StepVerifier.resetDefaultTimeout(); - } - - @BeforeEach - public final void baseSetUp(TestInfo testInfo) { - LOGGER.info( - "***** Test started : " - + testInfo.getTestClass().map(Class::getSimpleName).orElse("") - + "." - + testInfo.getTestMethod().map(Method::getName).orElse("") - + " *****"); - } - - @AfterEach - public final void baseTearDown(TestInfo testInfo) { - LOGGER.info( - "***** Test finished : " - + testInfo.getTestClass().map(Class::getSimpleName).orElse("") - + "." - + testInfo.getTestMethod().map(Method::getName).orElse("") - + " *****"); - } -} diff --git a/tokens/src/test/java/io/scalecube/security/tokens/jwt/JwtTokenResolverTests.java b/tokens/src/test/java/io/scalecube/security/tokens/jwt/JwtTokenResolverTests.java index 0a7129d..c97c240 100644 --- a/tokens/src/test/java/io/scalecube/security/tokens/jwt/JwtTokenResolverTests.java +++ b/tokens/src/test/java/io/scalecube/security/tokens/jwt/JwtTokenResolverTests.java @@ -1,7 +1,9 @@ package io.scalecube.security.tokens.jwt; -import java.io.IOException; +import static io.scalecube.security.tokens.jwt.Utils.toRsaPublicKey; + import java.security.Key; +import java.time.Duration; import java.util.Collections; import java.util.Map; import java.util.Properties; @@ -10,13 +12,14 @@ import org.mockito.Mockito; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; -class JwtTokenResolverTests extends BaseTest { +class JwtTokenResolverTests { private static final Map BODY = Collections.singletonMap("aud", "aud"); @Test - void testTokenResolver() throws IOException { + void testTokenResolver() throws Exception { TokenWithKey tokenWithKey = new TokenWithKey("token-and-pubkey.properties"); JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class); @@ -32,7 +35,9 @@ void testTokenResolver() throws IOException { KeyProvider keyProvider = Mockito.mock(KeyProvider.class); Mockito.when(keyProvider.findKey(tokenWithKey.kid)).thenReturn(Mono.just(tokenWithKey.key)); - JwtTokenResolverImpl tokenResolver = new JwtTokenResolverImpl(keyProvider, tokenParserFactory); + JwtTokenResolverImpl tokenResolver = + new JwtTokenResolverImpl( + keyProvider, tokenParserFactory, VirtualTimeScheduler.create(), Duration.ofSeconds(3)); // N times call resolve StepVerifier.create(tokenResolver.resolve(tokenWithKey.token).repeat(3)) @@ -45,7 +50,7 @@ void testTokenResolver() throws IOException { } @Test - void testTokenResolverWithRotatingKey() throws IOException { + void testTokenResolverWithRotatingKey() throws Exception { TokenWithKey tokenWithKey = new TokenWithKey("token-and-pubkey.properties"); TokenWithKey tokenWithKeyAfterRotation = new TokenWithKey("token-and-pubkey.after-rotation.properties"); @@ -70,7 +75,9 @@ void testTokenResolverWithRotatingKey() throws IOException { Mockito.when(keyProvider.findKey(tokenWithKeyAfterRotation.kid)) .thenReturn(Mono.just(tokenWithKeyAfterRotation.key)); - JwtTokenResolverImpl tokenResolver = new JwtTokenResolverImpl(keyProvider, tokenParserFactory); + JwtTokenResolverImpl tokenResolver = + new JwtTokenResolverImpl( + keyProvider, tokenParserFactory, VirtualTimeScheduler.create(), Duration.ofSeconds(3)); // Call normal token first StepVerifier.create(tokenResolver.resolve(tokenWithKey.token)) @@ -90,7 +97,7 @@ void testTokenResolverWithRotatingKey() throws IOException { } @Test - void testTokenResolverWithWrongKey() throws IOException { + void testTokenResolverWithWrongKey() throws Exception { TokenWithKey tokenWithWrongKey = new TokenWithKey("token-and-wrong-pubkey.properties"); JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class); @@ -106,7 +113,9 @@ void testTokenResolverWithWrongKey() throws IOException { Mockito.when(keyProvider.findKey(tokenWithWrongKey.kid)) .thenReturn(Mono.just(tokenWithWrongKey.key)); - JwtTokenResolverImpl tokenResolver = new JwtTokenResolverImpl(keyProvider, tokenParserFactory); + JwtTokenResolverImpl tokenResolver = + new JwtTokenResolverImpl( + keyProvider, tokenParserFactory, VirtualTimeScheduler.create(), Duration.ofSeconds(3)); // Must fail (retry N times) StepVerifier.create(tokenResolver.resolve(tokenWithWrongKey.token).retry(1)) @@ -118,7 +127,7 @@ void testTokenResolverWithWrongKey() throws IOException { } @Test - void testTokenResolverWhenKeyProviderFailing() throws IOException { + void testTokenResolverWhenKeyProviderFailing() throws Exception { TokenWithKey tokenWithKey = new TokenWithKey("token-and-pubkey.properties"); JwtTokenParser tokenParser = Mockito.mock(JwtTokenParser.class); @@ -134,7 +143,9 @@ void testTokenResolverWhenKeyProviderFailing() throws IOException { KeyProvider keyProvider = Mockito.mock(KeyProvider.class); Mockito.when(keyProvider.findKey(tokenWithKey.kid)).thenThrow(RuntimeException.class); - JwtTokenResolverImpl tokenResolver = new JwtTokenResolverImpl(keyProvider, tokenParserFactory); + JwtTokenResolverImpl tokenResolver = + new JwtTokenResolverImpl( + keyProvider, tokenParserFactory, VirtualTimeScheduler.create(), Duration.ofSeconds(3)); // Must fail with "hola" (retry N times) StepVerifier.create(tokenResolver.resolve(tokenWithKey.token).retry(1)).expectError().verify(); @@ -149,13 +160,13 @@ static class TokenWithKey { final Key key; final String kid; - TokenWithKey(String s) throws IOException { + TokenWithKey(String s) throws Exception { ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); Properties props = new Properties(); props.load(classLoader.getResourceAsStream(s)); this.token = props.getProperty("token"); this.kid = props.getProperty("kid"); - this.key = Utils.getRsaPublicKey(props.getProperty("n"), props.getProperty("e")); + this.key = toRsaPublicKey(props.getProperty("n"), props.getProperty("e")); } } } diff --git a/tokens/src/test/java/io/scalecube/security/tokens/jwt/JwksKeyProviderTests.java b/tokens/src/test/java/io/scalecube/security/tokens/jwt/VaultEnvironment.java similarity index 55% rename from tokens/src/test/java/io/scalecube/security/tokens/jwt/JwksKeyProviderTests.java rename to tokens/src/test/java/io/scalecube/security/tokens/jwt/VaultEnvironment.java index 94b2105..f2b9dab 100644 --- a/tokens/src/test/java/io/scalecube/security/tokens/jwt/JwksKeyProviderTests.java +++ b/tokens/src/test/java/io/scalecube/security/tokens/jwt/VaultEnvironment.java @@ -4,23 +4,14 @@ import com.bettercloud.vault.rest.Rest; import com.bettercloud.vault.rest.RestException; import com.bettercloud.vault.rest.RestResponse; -import io.jsonwebtoken.Claims; -import io.jsonwebtoken.Header; -import io.jsonwebtoken.Jwt; -import io.jsonwebtoken.JwtParserBuilder; -import io.jsonwebtoken.Jwts; import java.io.IOException; import java.util.UUID; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; import org.testcontainers.containers.Container.ExecResult; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; import org.testcontainers.vault.VaultContainer; -import reactor.test.StepVerifier; -class JwksKeyProviderTests extends BaseTest { +public class VaultEnvironment { private static final String VAULT_TOKEN = "test"; private static final String VAULT_TOKEN_HEADER = "X-Vault-Token"; @@ -32,58 +23,19 @@ class JwksKeyProviderTests extends BaseTest { private static String vaultAddr; - @BeforeAll - static void beforeAll() { + public static void start() { VAULT_CONTAINER.start(); vaultAddr = "http://localhost:" + VAULT_CONTAINER.getMappedPort(8200); } - @AfterAll - static void afterAll() { + public static void stop() { VAULT_CONTAINER.stop(); } - @Test - void testJwksKeysRetrieval() throws RestException, IOException, InterruptedException { - String keyName = createIdentityKey(vaultAddr); // oidc/key - String roleName = createIdentityRole(vaultAddr, keyName); // oidc/role - createIdentityTokenPolicy(roleName); // write policy policyfile.hcl - String clientToken = createEntity(roleName); // onboard some entity with policy line above - String token = generateIdentityToken(clientToken, roleName); // oidc/token - String kid = getKid(token); - - JwksKeyProvider keyProvider = new JwksKeyProvider(jwksUri(vaultAddr)); - - StepVerifier.create(keyProvider.findKey(kid)).expectNextCount(1).expectComplete().verify(); - } - - @Test - void testJwksKeysRetrievalKeyNotFound() { - JwksKeyProvider keyProvider = new JwksKeyProvider(jwksUri(vaultAddr)); - - StepVerifier.create(keyProvider.findKey(UUID.randomUUID().toString())) - .expectErrorMatches( - th -> th.getMessage() != null && th.getMessage().contains("Key was not found")) - .verify(); - } - - private static String getKid(String token) { - String justClaims = token.substring(0, token.lastIndexOf(".") + 1); - JwtParserBuilder parserBuilder = Jwts.parserBuilder(); - //noinspection rawtypes - Jwt claims = parserBuilder.build().parseClaimsJwt(justClaims); - //noinspection rawtypes - Header header = claims.getHeader(); - return (String) header.get("kid"); - } - - private static String generateIdentityToken(String clientToken, String roleName) + public static String generateIdentityToken(String clientToken, String roleName) throws RestException { RestResponse restResponse = - new Rest() - .header(VAULT_TOKEN_HEADER, clientToken) - .url(oidcToken(vaultAddr, roleName)) - .get(); + new Rest().header(VAULT_TOKEN_HEADER, clientToken).url(oidcToken(roleName)).get(); int status = restResponse.getStatus(); if (status != 200 && status != 204) { @@ -99,11 +51,11 @@ private static String generateIdentityToken(String clientToken, String roleName) .asString(); } - private static void createIdentityTokenPolicy(String roleName) throws RestException { + public static void createIdentityTokenPolicy(String roleName) throws RestException { int status = new Rest() .header(VAULT_TOKEN_HEADER, VAULT_TOKEN) - .url(policiesAclUri(vaultAddr, roleName)) + .url(policiesAclUri(roleName)) .body( ("{\"policy\":\"path \\\"identity/oidc/token/" + roleName @@ -118,7 +70,7 @@ private static void createIdentityTokenPolicy(String roleName) throws RestExcept } } - private static String createEntity(final String roleName) + public static String createEntity(final String roleName) throws IOException, InterruptedException { checkSuccess( @@ -142,29 +94,24 @@ private static String createEntity(final String roleName) .asString(); } - private static void checkSuccess(int exitCode) { + public static void checkSuccess(int exitCode) { if (exitCode != 0) { throw new IllegalStateException("Exited with error: " + exitCode); } } - private static String createIdentityKey(String vaultAddr) throws RestException { - return createIdentityKey(vaultAddr, "1m", "1m"); - } - - private static String createIdentityKey( - String vaultAddr, String rotationPeriod, String verificationTtl) throws RestException { + public static String createIdentityKey() throws RestException { String keyName = UUID.randomUUID().toString(); int status = new Rest() .header(VAULT_TOKEN_HEADER, VAULT_TOKEN) - .url(oidcKeyUrl(vaultAddr, keyName)) + .url(oidcKeyUrl(keyName)) .body( ("{\"rotation_period\":\"" - + rotationPeriod + + "1m" + "\", " + "\"verification_ttl\": \"" - + verificationTtl + + "1m" + "\", " + "\"allowed_client_ids\": \"*\", " + "\"algorithm\": \"RS256\"}") @@ -178,18 +125,13 @@ private static String createIdentityKey( return keyName; } - private static String createIdentityRole(String vaultAddr, String keyName) throws RestException { - return createIdentityRole(vaultAddr, keyName, "1h"); - } - - private static String createIdentityRole(String vaultAddr, String keyName, String ttl) - throws RestException { + public static String createIdentityRole(String keyName) throws RestException { String roleName = UUID.randomUUID().toString(); int status = new Rest() .header(VAULT_TOKEN_HEADER, VAULT_TOKEN) - .url(oidcRoleUrl(vaultAddr, roleName)) - .body(("{\"key\":\"" + keyName + "\",\"ttl\": \"" + ttl + "\"}").getBytes()) + .url(oidcRoleUrl(roleName)) + .body(("{\"key\":\"" + keyName + "\",\"ttl\": \"" + "1h" + "\"}").getBytes()) .post() .getStatus(); @@ -199,23 +141,23 @@ private static String createIdentityRole(String vaultAddr, String keyName, Strin return roleName; } - private static String oidcKeyUrl(String vaultAddr, String keyName) { + public static String oidcKeyUrl(String keyName) { return vaultAddr + "/v1/identity/oidc/key/" + keyName; } - private static String oidcRoleUrl(String vaultAddr, String roleName) { + public static String oidcRoleUrl(String roleName) { return vaultAddr + "/v1/identity/oidc/role/" + roleName; } - private static String oidcToken(String vaultAddr, String roleName) { + public static String oidcToken(String roleName) { return vaultAddr + "/v1/identity/oidc/token/" + roleName; } - private static String jwksUri(String vaultAddr) { + public static String jwksUri() { return vaultAddr + "/v1/identity/oidc/.well-known/keys"; } - private static String policiesAclUri(String vaultAddr, String roleName) { + public static String policiesAclUri(String roleName) { return vaultAddr + "/v1/sys/policies/acl/" + roleName; } } diff --git a/tokens/src/test/java/io/scalecube/security/tokens/jwt/VaultJwksKeyProviderTests.java b/tokens/src/test/java/io/scalecube/security/tokens/jwt/VaultJwksKeyProviderTests.java new file mode 100644 index 0000000..4480613 --- /dev/null +++ b/tokens/src/test/java/io/scalecube/security/tokens/jwt/VaultJwksKeyProviderTests.java @@ -0,0 +1,101 @@ +package io.scalecube.security.tokens.jwt; + +import static io.scalecube.security.tokens.jwt.VaultEnvironment.createEntity; +import static io.scalecube.security.tokens.jwt.VaultEnvironment.createIdentityKey; +import static io.scalecube.security.tokens.jwt.VaultEnvironment.createIdentityRole; +import static io.scalecube.security.tokens.jwt.VaultEnvironment.createIdentityTokenPolicy; +import static io.scalecube.security.tokens.jwt.VaultEnvironment.generateIdentityToken; +import static io.scalecube.security.tokens.jwt.VaultEnvironment.jwksUri; +import static org.hamcrest.CoreMatchers.startsWith; + +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.Header; +import io.jsonwebtoken.Jwt; +import io.jsonwebtoken.JwtParserBuilder; +import io.jsonwebtoken.Jwts; +import java.time.Duration; +import java.util.UUID; +import org.junit.Assert; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +class VaultJwksKeyProviderTests { + + private static final Duration TIMEOUT = Duration.ofSeconds(3); + + @BeforeEach + void setup() { + VaultEnvironment.start(); + } + + @AfterEach + void cleanup() { + VaultEnvironment.stop(); + } + + @Test + @DisplayName("Find key successfully") + void testFindKey() throws Exception { + String keyName = createIdentityKey(); // oidc/key + String roleName = createIdentityRole(keyName); // oidc/role + createIdentityTokenPolicy(roleName); // write policy policyfile.hcl + String clientToken = createEntity(roleName); // onboard some entity with policy line above + String token = generateIdentityToken(clientToken, roleName); // oidc/token + String kid = getKid(token); + + JwksKeyProvider keyProvider = new JwksKeyProvider(jwksUri()); + + StepVerifier.create(keyProvider.findKey(kid)) + .expectNextCount(1) + .expectComplete() + .verify(TIMEOUT); + } + + @Test + @DisplayName("Fails to find non-existent key") + void testFindNonExistentKey() throws Exception { + String keyName = createIdentityKey(); // oidc/key + String roleName = createIdentityRole(keyName); // oidc/role + createIdentityTokenPolicy(roleName); // write policy policyfile.hcl + String clientToken = createEntity(roleName); // onboard some entity with policy line above + generateIdentityToken(clientToken, roleName); // oidc/token + + JwksKeyProvider keyProvider = new JwksKeyProvider(jwksUri()); + + StepVerifier.create(keyProvider.findKey(UUID.randomUUID().toString())) + .expectErrorSatisfies( + throwable -> { + Assertions.assertEquals(throwable.getClass(), KeyProviderException.class); + Assert.assertThat(throwable.getMessage(), startsWith("Key was not found")); + }) + .verify(TIMEOUT); + } + + @Test + @DisplayName("Fails to find key on empty environment") + void testKeyNotFoundOnEmptyEnvironment() { + JwksKeyProvider keyProvider = new JwksKeyProvider(jwksUri()); + + StepVerifier.create(keyProvider.findKey(UUID.randomUUID().toString())) + .expectErrorSatisfies( + throwable -> { + Assertions.assertEquals(throwable.getClass(), KeyProviderException.class); + Assert.assertThat(throwable.getMessage(), startsWith("Key was not found")); + }) + .verify(TIMEOUT); + } + + private static String getKid(String token) { + String justClaims = token.substring(0, token.lastIndexOf(".") + 1); + JwtParserBuilder parserBuilder = Jwts.parserBuilder(); + //noinspection rawtypes + Jwt claims = parserBuilder.build().parseClaimsJwt(justClaims); + //noinspection rawtypes + Header header = claims.getHeader(); + return (String) header.get("kid"); + } +}