diff --git a/spring-kafka/src/main/java/org/springframework/kafka/support/DefaultKafkaHeaderMapper.java b/spring-kafka/src/main/java/org/springframework/kafka/support/DefaultKafkaHeaderMapper.java index b1b0a84f6c..099da67fd5 100644 --- a/spring-kafka/src/main/java/org/springframework/kafka/support/DefaultKafkaHeaderMapper.java +++ b/spring-kafka/src/main/java/org/springframework/kafka/support/DefaultKafkaHeaderMapper.java @@ -20,6 +20,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -61,6 +62,16 @@ public class DefaultKafkaHeaderMapper extends AbstractKafkaHeaderMapper { private static final String JAVA_LANG_STRING = "java.lang.String"; + private static final Set TRUSTED_ARRAY_TYPES = + new HashSet<>(Arrays.asList( + "[B", + "[I", + "[J", + "[F", + "[D", + "[C" + )); + private static final List DEFAULT_TRUSTED_PACKAGES = Arrays.asList( "java.lang", @@ -362,12 +373,16 @@ protected boolean trusted(String requestedType) { if (requestedType.equals(NonTrustedHeaderType.class.getName())) { return true; } + if (TRUSTED_ARRAY_TYPES.contains(requestedType)) { + return true; + } + String type = requestedType.startsWith("[") ? requestedType.substring(2) : requestedType; if (!this.trustedPackages.isEmpty()) { - int lastDot = requestedType.lastIndexOf('.'); + int lastDot = type.lastIndexOf('.'); if (lastDot < 0) { return false; } - String packageName = requestedType.substring(0, lastDot); + String packageName = type.substring(0, lastDot); for (String trustedPackage : this.trustedPackages) { if (packageName.equals(trustedPackage) || packageName.startsWith(trustedPackage + ".")) { return true; diff --git a/spring-kafka/src/test/java/org/springframework/kafka/support/DefaultKafkaHeaderMapperTests.java b/spring-kafka/src/test/java/org/springframework/kafka/support/DefaultKafkaHeaderMapperTests.java index 9057aff3ac..5b1fd3a158 100644 --- a/spring-kafka/src/test/java/org/springframework/kafka/support/DefaultKafkaHeaderMapperTests.java +++ b/spring-kafka/src/test/java/org/springframework/kafka/support/DefaultKafkaHeaderMapperTests.java @@ -66,10 +66,23 @@ void testTrustedAndNot() { .setHeader("simpleContentType", MimeTypeUtils.TEXT_PLAIN_VALUE) .setHeader("customToString", new Bar("fiz")) .setHeader("uri", URI.create("https://foo.bar")) + .setHeader("intA", new int[] { 42 }) + .setHeader("longA", new long[] { 42L }) + .setHeader("floatA", new float[] { 1.0f }) + .setHeader("doubleA", new double[] { 1.0 }) + .setHeader("charA", new char[] { 'c' }) + .setHeader("boolA", new boolean[] { true }) + .setHeader("IntA", new Integer[] { 42 }) + .setHeader("LongA", new Long[] { 42L }) + .setHeader("FloatA", new Float[] { 1.0f }) + .setHeader("DoubleA", new Double[] { 1.0 }) + .setHeader("CharA", new Character[] { 'c' }) + .setHeader("BoolA", new Boolean[] { true }) + .setHeader("stringA", new String[] { "array" }) .build(); RecordHeaders recordHeaders = new RecordHeaders(); mapper.fromHeaders(message.getHeaders(), recordHeaders); - assertThat(recordHeaders.toArray().length).isEqualTo(10); // 9 + json_types + assertThat(recordHeaders.toArray().length).isEqualTo(23); // 21 + json_types Map headers = new HashMap<>(); mapper.toHeaders(recordHeaders, headers); assertThat(headers.get("foo")).isInstanceOf(byte[].class); @@ -83,10 +96,21 @@ void testTrustedAndNot() { assertThat(headers.get(MessageHeaders.ERROR_CHANNEL)).isEqualTo("errors"); assertThat(headers.get("customToString")).isEqualTo("Bar [field=fiz]"); assertThat(headers.get("uri")).isEqualTo(URI.create("https://foo.bar")); + assertThat(headers.get("intA")).isEqualTo(new int[] { 42 }); + assertThat(headers.get("longA")).isEqualTo(new long[] { 42L }); + assertThat(headers.get("floatA")).isEqualTo(new float[] { 1.0f }); + assertThat(headers.get("doubleA")).isEqualTo(new double[] { 1.0 }); + assertThat(headers.get("charA")).isEqualTo(new char[] { 'c' }); + assertThat(headers.get("IntA")).isEqualTo(new Integer[] { 42 }); + assertThat(headers.get("LongA")).isEqualTo(new Long[] { 42L }); + assertThat(headers.get("FloatA")).isEqualTo(new Float[] { 1.0f }); + assertThat(headers.get("DoubleA")).isEqualTo(new Double[] { 1.0 }); + assertThat(headers.get("CharA")).isEqualTo(new Character[] { 'c' }); + assertThat(headers.get("stringA")).isEqualTo(new String[] { "array" }); NonTrustedHeaderType ntht = (NonTrustedHeaderType) headers.get("fix"); assertThat(ntht.getHeaderValue()).isNotNull(); assertThat(ntht.getUntrustedType()).isEqualTo(Foo.class.getName()); - assertThat(headers).hasSize(9); + assertThat(headers).hasSize(22); mapper.addTrustedPackages(getClass().getPackage().getName()); headers = new HashMap<>(); @@ -95,7 +119,7 @@ void testTrustedAndNot() { assertThat(new String((byte[]) headers.get("foo"))).isEqualTo("bar"); assertThat(headers.get("baz")).isEqualTo("qux"); assertThat(headers.get("fix")).isEqualTo(new Foo()); - assertThat(headers).hasSize(9); + assertThat(headers).hasSize(22); } @Test