diff --git a/spring-rabbit-stream/src/main/java/org/springframework/rabbit/stream/config/SuperStream.java b/spring-rabbit-stream/src/main/java/org/springframework/rabbit/stream/config/SuperStream.java index 73fe76a9ff..0d74183d79 100644 --- a/spring-rabbit-stream/src/main/java/org/springframework/rabbit/stream/config/SuperStream.java +++ b/spring-rabbit-stream/src/main/java/org/springframework/rabbit/stream/config/SuperStream.java @@ -20,6 +20,8 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; +import java.util.stream.Collectors; import java.util.stream.IntStream; import org.springframework.amqp.core.Binding; @@ -28,6 +30,7 @@ import org.springframework.amqp.core.Declarables; import org.springframework.amqp.core.DirectExchange; import org.springframework.amqp.core.Queue; +import org.springframework.util.Assert; /** * Create Super Stream Topology {@link Declarable}s. @@ -44,16 +47,33 @@ public class SuperStream extends Declarables { * @param partitions the number of partitions. */ public SuperStream(String name, int partitions) { - super(declarables(name, partitions)); + this(name, partitions, (q, i) -> IntStream.range(0, i) + .mapToObj(String::valueOf) + .collect(Collectors.toList())); } - private static Collection declarables(String name, int partitions) { + /** + * Create a Super Stream with the provided parameters. + * @param name the stream name. + * @param partitions the number of partitions. + * @param routingKeyStrategy a strategy to determine routing keys to use for the + * partitions. The first parameter is the queue name, the second the number of + * partitions, the returned list must have a size equal to the partitions. + */ + public SuperStream(String name, int partitions, BiFunction> routingKeyStrategy) { + super(declarables(name, partitions, routingKeyStrategy)); + } + + private static Collection declarables(String name, int partitions, + BiFunction> routingKeyStrategy) { + List declarables = new ArrayList<>(); - String[] rks = IntStream.range(0, partitions).mapToObj(String::valueOf).toArray(String[]::new); + List rks = routingKeyStrategy.apply(name, partitions); + Assert.state(rks.size() == partitions, () -> "Expected " + partitions + " routing keys, not " + rks.size()); declarables.add(new DirectExchange(name, true, false, Map.of("x-super-stream", true))); for (int i = 0; i < partitions; i++) { - String rk = rks[i]; - Queue q = new Queue(name + "-" + rk, true, false, false, Map.of("x-queue-type", "stream")); + String rk = rks.get(i); + Queue q = new Queue(name + "-" + i, true, false, false, Map.of("x-queue-type", "stream")); declarables.add(q); declarables.add(new Binding(q.getName(), DestinationType.QUEUE, name, rk, Map.of("x-stream-partition-order", i))); diff --git a/spring-rabbit-stream/src/test/java/org/springframework/rabbit/stream/listener/SuperStreamSACTests.java b/spring-rabbit-stream/src/test/java/org/springframework/rabbit/stream/listener/SuperStreamSACTests.java index e9d3a28ced..2dd1a4614e 100644 --- a/spring-rabbit-stream/src/test/java/org/springframework/rabbit/stream/listener/SuperStreamSACTests.java +++ b/spring-rabbit-stream/src/test/java/org/springframework/rabbit/stream/listener/SuperStreamSACTests.java @@ -24,6 +24,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.junit.jupiter.api.Test; @@ -69,9 +71,9 @@ void superStream(@Autowired ApplicationContext context, @Autowired RabbitTemplat container2.start(); StreamListenerContainer container3 = context.getBean(StreamListenerContainer.class, env, "three"); container3.start(); - template.convertAndSend("ss.sac.test", "0", "foo"); - template.convertAndSend("ss.sac.test", "1", "bar"); - template.convertAndSend("ss.sac.test", "2", "baz"); + template.convertAndSend("ss.sac.test", "rk-0", "foo"); + template.convertAndSend("ss.sac.test", "rk-1", "bar"); + template.convertAndSend("ss.sac.test", "rk-2", "baz"); assertThat(config.latch.await(10, TimeUnit.SECONDS)).isTrue(); assertThat(config.messages.keySet()).contains("one", "two", "three"); assertThat(config.info).contains("one:foo", "two:bar", "three:baz"); @@ -112,7 +114,9 @@ RabbitTemplate template(ConnectionFactory cf) { @Bean SuperStream superStream() { - return new SuperStream("ss.sac.test", 3); + return new SuperStream("ss.sac.test", 3, (q, i) -> IntStream.range(0, i) + .mapToObj(j -> "rk-" + j) + .collect(Collectors.toList())); } @Bean