From ec5ab8b087e1c484f7d50c0da6fd00e6ccb41026 Mon Sep 17 00:00:00 2001 From: Nanne Baars Date: Thu, 1 Jun 2023 08:47:31 +0200 Subject: [PATCH] VaultTransitOperations with versioned keys Closes gh-726 --- .../vault/core/VaultTransitTemplate.java | 6 +- .../vault/support/VaultTransitContext.java | 36 +++++-- .../VaultTransitTemplateIntegrationTests.java | 97 ++++++++++++++++--- 3 files changed, 114 insertions(+), 25 deletions(-) diff --git a/spring-vault-core/src/main/java/org/springframework/vault/core/VaultTransitTemplate.java b/spring-vault-core/src/main/java/org/springframework/vault/core/VaultTransitTemplate.java index ad0e374f4..5cb5887e2 100644 --- a/spring-vault-core/src/main/java/org/springframework/vault/core/VaultTransitTemplate.java +++ b/spring-vault-core/src/main/java/org/springframework/vault/core/VaultTransitTemplate.java @@ -16,8 +16,6 @@ package org.springframework.vault.core; import com.fasterxml.jackson.annotation.JsonProperty; -import org.jetbrains.annotations.NotNull; - import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; @@ -478,6 +476,10 @@ static void applyTransitOptions(VaultTransitContext context, Map if (!ObjectUtils.isEmpty(context.getNonce())) { request.put("nonce", Base64.getEncoder().encodeToString(context.getNonce())); } + + if (context.getKeyVersion() != 0) { + request.put("key_version", "" + context.getKeyVersion()); + } } static List toEncryptionResults(VaultResponse vaultResponse, List batchRequest) { diff --git a/spring-vault-core/src/main/java/org/springframework/vault/support/VaultTransitContext.java b/spring-vault-core/src/main/java/org/springframework/vault/support/VaultTransitContext.java index f3957bd25..ab65e821e 100644 --- a/spring-vault-core/src/main/java/org/springframework/vault/support/VaultTransitContext.java +++ b/spring-vault-core/src/main/java/org/springframework/vault/support/VaultTransitContext.java @@ -15,10 +15,10 @@ */ package org.springframework.vault.support; -import java.util.Arrays; - import org.springframework.util.Assert; +import java.util.Arrays; + /** * Transit backend encryption/decryption/rewrapping context. * @@ -30,15 +30,18 @@ public class VaultTransitContext { * Empty (default) {@link VaultTransitContext} without a {@literal context} and * {@literal nonce}. */ - private static final VaultTransitContext EMPTY = new VaultTransitContext(new byte[0], new byte[0]); + private static final VaultTransitContext EMPTY = new VaultTransitContext(new byte[0], new byte[0], 0); private final byte[] context; private final byte[] nonce; - VaultTransitContext(byte[] context, byte[] nonce) { + private final int keyVersion; + + VaultTransitContext(byte[] context, byte[] nonce, int keyVersion) { this.context = context; this.nonce = nonce; + this.keyVersion = keyVersion; } /** @@ -89,6 +92,15 @@ public byte[] getNonce() { return this.nonce; } + /** + * @return the version of the key to use for the operation. If not set, uses the + * latest version. Must be greater than or equal to the key's min_encryption_version, + * if set. + */ + public int getKeyVersion() { + return this.keyVersion; + } + @Override public boolean equals(Object o) { if (this == o) @@ -96,13 +108,14 @@ public boolean equals(Object o) { if (!(o instanceof VaultTransitContext)) return false; VaultTransitContext that = (VaultTransitContext) o; - return Arrays.equals(this.context, that.context) && Arrays.equals(this.nonce, that.nonce); + return Arrays.equals(this.context, that.context) && Arrays.equals(this.nonce, that.nonce) + && this.keyVersion == that.keyVersion; } @Override public int hashCode() { int result = Arrays.hashCode(this.context); - result = 31 * result + Arrays.hashCode(this.nonce); + result = 31 * result + Arrays.hashCode(this.nonce) + this.keyVersion; return result; } @@ -115,6 +128,8 @@ public static class VaultTransitRequestBuilder { private byte[] nonce = new byte[0]; + private int keyVersion; + VaultTransitRequestBuilder() { } @@ -149,12 +164,19 @@ public VaultTransitRequestBuilder nonce(byte[] nonce) { return this; } + public VaultTransitRequestBuilder keyVersion(int keyVersion) { + Assert.isTrue(keyVersion >= 0, "Key version must have a positive value"); + + this.keyVersion = keyVersion; + return this; + } + /** * Build a new {@link VaultTransitContext} instance. * @return a new {@link VaultTransitContext}. */ public VaultTransitContext build() { - return new VaultTransitContext(this.context, this.nonce); + return new VaultTransitContext(this.context, this.nonce, this.keyVersion); } } diff --git a/spring-vault-core/src/test/java/org/springframework/vault/core/VaultTransitTemplateIntegrationTests.java b/spring-vault-core/src/test/java/org/springframework/vault/core/VaultTransitTemplateIntegrationTests.java index 2c204c33a..eb7c3b1c2 100644 --- a/spring-vault-core/src/test/java/org/springframework/vault/core/VaultTransitTemplateIntegrationTests.java +++ b/spring-vault-core/src/test/java/org/springframework/vault/core/VaultTransitTemplateIntegrationTests.java @@ -15,25 +15,48 @@ */ package org.springframework.vault.core; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; - +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.vault.VaultException; -import org.springframework.vault.support.*; +import org.springframework.vault.support.Ciphertext; +import org.springframework.vault.support.Hmac; +import org.springframework.vault.support.Plaintext; +import org.springframework.vault.support.RawTransitKey; +import org.springframework.vault.support.Signature; +import org.springframework.vault.support.SignatureValidation; +import org.springframework.vault.support.TransitKeyType; +import org.springframework.vault.support.VaultDecryptionResult; +import org.springframework.vault.support.VaultEncryptionResult; +import org.springframework.vault.support.VaultHmacRequest; +import org.springframework.vault.support.VaultMount; +import org.springframework.vault.support.VaultSignRequest; +import org.springframework.vault.support.VaultSignatureVerificationRequest; +import org.springframework.vault.support.VaultTransitContext; +import org.springframework.vault.support.VaultTransitKey; +import org.springframework.vault.support.VaultTransitKeyConfiguration; +import org.springframework.vault.support.VaultTransitKeyCreationRequest; import org.springframework.vault.util.IntegrationTestSupport; import org.springframework.vault.util.RequiresVaultVersion; import org.springframework.vault.util.Version; -import static org.assertj.core.api.Assertions.*; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.fail; /** * Integration tests for {@link VaultTransitTemplate} through @@ -311,19 +334,29 @@ void encryptShouldCreateCiphertext() { assertThat(ciphertext).startsWith("vault:v"); } - @Test - void encryptShouldCreateCiphertextWithNonceAndContext() { + private static Stream<Arguments> encryptWithKeyVersion() { + return Stream.of(Arguments.of(1, 1, "v1"), Arguments.of(2, 2, "v2"), Arguments.of(1, 2, ""), + Arguments.of(2, 1, "v1"), Arguments.of("2", "0", "v2")); + } - this.transitOperations.createKey("mykey", - VaultTransitKeyCreationRequest.builder().convergentEncryption(true).derived(true).build()); + @ParameterizedTest + @MethodSource + void encryptWithKeyVersion(int keyVersion, int usedKeyVersionWhileEncrypting, String expectedKeyPrefix) { + this.transitOperations.createKey("mykey", VaultTransitKeyCreationRequest.builder().build()); + // rotate the key to get the right version + IntStream.range(0, keyVersion - 1).forEach(__ -> this.transitOperations.rotate("mykey")); VaultTransitContext transitRequest = VaultTransitContext.builder() - .context("blubb".getBytes()) // - .nonce("123456789012".getBytes()) // + .keyVersion(usedKeyVersionWhileEncrypting) .build(); - String ciphertext = this.transitOperations.encrypt("mykey", "hello-world".getBytes(), transitRequest); - assertThat(ciphertext).startsWith("vault:v1:"); + try { + String ciphertext = this.transitOperations.encrypt("mykey", "hello-world".getBytes(), transitRequest); + assertThat(ciphertext).startsWith("vault:%s:".formatted(expectedKeyPrefix)); + } + catch (Exception e) { + Assertions.assertThat(expectedKeyPrefix).isNullOrEmpty(); + } } @Test @@ -372,6 +405,38 @@ void decryptShouldCreatePlaintext() { assertThat(plaintext).isEqualTo("hello-world"); } + private static Stream<Arguments> decryptWithKeyVersion() { + return Stream.of(Arguments.of(1, 1, true), Arguments.of(2, 2, true), Arguments.of(1, 2, false), + Arguments.of(2, 1, true), Arguments.of("2", "0", true)); + } + + @ParameterizedTest + @MethodSource + void decryptWithKeyVersion(int keyVersion, int usedKeyVersionWhileEncrypting, boolean shouldPass) { + this.transitOperations.createKey("mykey"); + // rotate the key to get the right version + IntStream.range(0, keyVersion - 1).forEach(__ -> this.transitOperations.rotate("mykey")); + + VaultTransitContext transitRequest = VaultTransitContext.builder() + .keyVersion(usedKeyVersionWhileEncrypting) + .build(); + + try { + String ciphertext = this.transitOperations + .encrypt("mykey", Plaintext.of("hello-world").with(transitRequest)) + .getCiphertext(); + String plaintext = Plaintext.of(this.transitOperations.decrypt("mykey", ciphertext, transitRequest)) + .asString(); + + assertThat(shouldPass).isTrue(); + assertThat(plaintext).isEqualTo("hello-world"); + + } + catch (VaultException e) { + assertThat(shouldPass).isFalse(); + } + } + @Test void decryptShouldCreatePlaintextWithNonceAndContext() { @@ -564,7 +629,7 @@ void shouldBatchDecryptWithWrongContext() { } catch (VaultException e) { assertThat(e).hasMessageContaining("error"); // Vault 1.6 behavior is - // different + // different } }