Skip to content

Commit

Permalink
Multi output partition issues
Browse files Browse the repository at this point in the history
 - When using reactive functions, the partition selector strategy does not
   use the configured partition count for multiple outbounds. This is
   because we take the first configured output binding and apply its
   partition count to all the outbound reactive streams (Tuples).
   Address this issue by applying the correct partition handling
   per output binding.

Resolves spring-cloud#2750
  • Loading branch information
sobychacko committed Jun 12, 2023
1 parent abf5aab commit 0f00684
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 6 deletions.
@@ -0,0 +1,148 @@
/*
* Copyright 2023-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.cloud.stream.binder.kafka;

import java.util.Map;
import java.util.function.Function;

import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import reactor.core.publisher.Flux;
import reactor.core.publisher.UnicastProcessor;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.WebApplicationType;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.cloud.stream.binder.PartitionSelectorStrategy;
import org.springframework.cloud.stream.function.StreamBridge;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.kafka.core.DefaultKafkaConsumerFactory;
import org.springframework.kafka.test.EmbeddedKafkaBroker;
import org.springframework.kafka.test.context.EmbeddedKafka;
import org.springframework.kafka.test.utils.KafkaTestUtils;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.junit.jupiter.SpringExtension;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Verifies that a reactive function with more than one output binding has the partition
 * of each outbound record computed from that binding's own producer properties rather
 * than from the first output binding's.
 *
 * @author Soby Chacko
 */
@ExtendWith(SpringExtension.class)
@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD)
@EmbeddedKafka(topics = { "odd-topic", "even-topic" })
class MultipleOutputBindingsPartitionsTests {

	@Autowired
	private EmbeddedKafkaBroker embeddedKafka;

	/**
	 * Sends one odd and one even number through {@code singleInputMultipleOutputs} and checks
	 * that each result lands on the last partition of its own topic. The output bindings are
	 * configured with 10 and 5 partitions respectively, and {@link MyPartitionSelector}
	 * always picks {@code partitionCount - 1}, so the expected partitions are 9 and 4.
	 */
	@Test
	void singleInputTupleOutputWithDifferentPartitions() {
		try (ConfigurableApplicationContext context = new SpringApplicationBuilder(MultiOutputApplication.class)
			.web(WebApplicationType.NONE).run(
				"--server.port=0",
				"--spring.jmx.enabled=false",
				"--spring.kafka.consumer.metadata.max.age.ms=1000",
				"--spring.cloud.function.definition=singleInputMultipleOutputs",
				"--spring.cloud.stream.function.reactive.singleInputMultipleOutputs=true",
				"--spring.cloud.stream.kafka.binder.autoAddPartitions=true",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-in-0.group=grp5",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-in-0.destination=multi-input",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-0.destination=odd-topic",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-1.destination=even-topic",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-0.producer.partitionKeyExpression=payload",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-0.producer.partitionCount=10",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-0.producer.partitionSelectorName=mySelectorStrategy",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-1.producer.partitionKeyExpression=payload",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-1.producer.partitionCount=5",
				"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-1.producer.partitionSelectorName=mySelectorStrategy",
				"--spring.cloud.stream.kafka.binder.brokers=" + embeddedKafka.getBrokersAsString())) {

			StreamBridge streamBridge = context.getBean(StreamBridge.class);
			streamBridge.send("multi-input", MessageBuilder.withPayload(101).build());
			streamBridge.send("multi-input", MessageBuilder.withPayload(102).build());

			Map<String, Object> consumerProps = KafkaTestUtils.consumerProps("group3", "false", embeddedKafka);
			consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
			DefaultKafkaConsumerFactory<String, String> cf = new DefaultKafkaConsumerFactory<>(consumerProps);

			// Both consumers are closed when the block exits, including on assertion failure,
			// so they do not outlive the test.
			try (Consumer<String, String> consumer1 = cf.createConsumer();
					Consumer<String, String> consumer2 = cf.createConsumer("group4", null)) {
				embeddedKafka.consumeFromEmbeddedTopics(consumer1, "odd-topic");
				embeddedKafka.consumeFromEmbeddedTopics(consumer2, "even-topic");

				ConsumerRecord<String, String> record1 = KafkaTestUtils.getSingleRecord(consumer1, "odd-topic");
				assertThat(record1)
					.isNotNull()
					.extracting(ConsumerRecord::value)
					.isEqualTo("ODD: 101");
				// out-0 is configured with partitionCount=10, so the selector yields 9.
				assertThat(record1)
					.extracting(ConsumerRecord::partition)
					.isEqualTo(9);

				ConsumerRecord<String, String> record2 = KafkaTestUtils.getSingleRecord(consumer2, "even-topic");
				assertThat(record2)
					.isNotNull()
					.extracting(ConsumerRecord::value)
					.isEqualTo("EVEN: 102");
				// out-1 is configured with partitionCount=5, so the selector yields 4.
				assertThat(record2)
					.extracting(ConsumerRecord::partition)
					.isEqualTo(4);
			}
		}
	}


	/**
	 * Test application exposing a single-input, two-output reactive function together with
	 * the custom partition selector referenced by both output bindings.
	 */
	@EnableAutoConfiguration
	public static class MultiOutputApplication {

		/**
		 * Splits the incoming numbers into two outputs: the first carries {@code "ODD: n"}
		 * for odd values, the second {@code "EVEN: n"} for even values.
		 * @return the function routing each number to the odd or even output
		 */
		@Bean
		public static Function<Flux<Integer>, Tuple2<Flux<String>, Flux<String>>> singleInputMultipleOutputs() {
			return flux -> {
				// Share one upstream subscription between the odd and the even branch.
				Flux<Integer> connectedFlux = flux.publish().autoConnect(2);
				UnicastProcessor<String> odd = UnicastProcessor.create();
				UnicastProcessor<String> even = UnicastProcessor.create();
				Flux<Integer> oddFlux = connectedFlux
					.filter(number -> number % 2 != 0)
					.doOnNext(number -> odd.onNext("ODD: " + number));
				Flux<Integer> evenFlux = connectedFlux
					.filter(number -> number % 2 == 0)
					.doOnNext(number -> even.onNext("EVEN: " + number));
				// Each branch is only subscribed once its corresponding output is subscribed.
				return Tuples.of(
					Flux.from(odd).doOnSubscribe(x -> oddFlux.subscribe()),
					Flux.from(even).doOnSubscribe(x -> evenFlux.subscribe()));
			};
		}

		@Bean
		public PartitionSelectorStrategy mySelectorStrategy() {
			return new MyPartitionSelector();
		}
	}

	/**
	 * Selector that ignores the key and always returns the highest partition index, which
	 * makes the expected partition a direct function of the configured partition count.
	 */
	static class MyPartitionSelector implements PartitionSelectorStrategy {

		@Override
		public int selectPartition(Object key, int partitionCount) {
			// selecting the last partition for easy test verification.
			return partitionCount - 1;
		}
	}
}
Expand Up @@ -173,7 +173,11 @@ void testSingleInputMultiOutput() {
ReactiveFunctionConfiguration.class))
.web(WebApplicationType.NONE)
.run("--spring.jmx.enabled=false",
"--spring.cloud.function.definition=singleInputMultipleOutputs")) {
"--spring.cloud.function.definition=singleInputMultipleOutputs",
"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-0.producer.partitionKeyExpression=payload",
"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-0.producer.partitionCount=10",
"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-1.producer.partitionKeyExpression=payload",
"--spring.cloud.stream.bindings.singleInputMultipleOutputs-out-1.producer.partitionCount=10")) {
InputDestination inputDestination = context.getBean(InputDestination.class);
OutputDestination outputDestination = context.getBean(OutputDestination.class);

Expand Down
Expand Up @@ -33,6 +33,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.StreamSupport;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -63,8 +64,10 @@
import org.springframework.cloud.function.context.config.RoutingFunction;
import org.springframework.cloud.function.context.message.MessageUtils;
import org.springframework.cloud.stream.binder.BinderFactory;
import org.springframework.cloud.stream.binder.BinderHeaders;
import org.springframework.cloud.stream.binder.BindingCreatedEvent;
import org.springframework.cloud.stream.binder.ConsumerProperties;
import org.springframework.cloud.stream.binder.PartitionHandler;
import org.springframework.cloud.stream.binder.ProducerProperties;
import org.springframework.cloud.stream.binder.ProducerProperties.PollerProperties;
import org.springframework.cloud.stream.binding.BindableProxyFactory;
Expand All @@ -86,12 +89,14 @@
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.type.MethodMetadata;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.integration.channel.AbstractMessageChannel;
import org.springframework.integration.channel.AbstractSubscribableChannel;
import org.springframework.integration.channel.FluxMessageChannel;
import org.springframework.integration.core.MessagingTemplate;
import org.springframework.integration.dsl.IntegrationFlow;
import org.springframework.integration.dsl.IntegrationFlowBuilder;
import org.springframework.integration.expression.ExpressionUtils;
import org.springframework.integration.handler.AbstractMessageHandler;
import org.springframework.integration.scheduling.PollerMetadata;
import org.springframework.integration.support.MessageBuilder;
Expand Down Expand Up @@ -553,20 +558,34 @@ private void bindFunctionToDestinations(BindableProxyFactory bindableProxyFactor
function.setSkipOutputConversion(producerProperties.isUseNativeEncoding());
}
functionToInvoke = new PartitionAwareFunctionWrapper(function, this.applicationContext, producerProperties);
// If we have a multi-output scenario, we will do any message enrichment (aka, determining the outbound
// partition) via the corresponding reactive Flux types. Currently, we support multiple output
// bindings for reactive types only (Tuples).
if (outputBindingNames.size() > 1) {
((PartitionAwareFunctionWrapper) functionToInvoke).setMessageEnricherEnabled(false);
}
}

Object resultPublishers = functionToInvoke.apply(inputPublishers.length == 1 ? inputPublishers[0] : Tuples.fromArray(inputPublishers));

if (!(resultPublishers instanceof Iterable)) {
resultPublishers = Collections.singletonList(resultPublishers);
}
Iterator<String> outputBindingIter = outputBindingNames.iterator();
long outputCount = StreamSupport.stream(((Iterable) resultPublishers).spliterator(), false).count();

((Iterable) resultPublishers).forEach(publisher -> {
Flux flux = Flux.from((Publisher) publisher);
if (!CollectionUtils.isEmpty(outputBindingNames)) {
MessageChannel outputChannel = this.applicationContext.getBean(outputBindingIter.next(), MessageChannel.class);
String outputBinding = outputBindingIter.next();
MessageChannel outputChannel = this.applicationContext.getBean(outputBinding, MessageChannel.class);
flux = flux.doOnNext(message -> {
// If there are more than 1 output bindings, then ensure that we properly calculate the partitions
// based on information from the correct output binding properties.
if (outputCount > 1) {
Integer partitionId = determinePartitionForOutputBinding(outputBinding, message);
message = MessageBuilder
.fromMessage((Message<?>) message)
.setHeader(BinderHeaders.PARTITION_HEADER, partitionId).build();
}
if (message instanceof Message m && m.getHeaders().get("spring.cloud.stream.sendto.destination") != null) {
String destinationName = (String) m.getHeaders().get("spring.cloud.stream.sendto.destination");
ProducerProperties producerProperties = this.serviceProperties.getBindings().get(outputBindingNames.iterator().next()).getProducer();
Expand Down Expand Up @@ -606,6 +625,19 @@ private void bindFunctionToDestinations(BindableProxyFactory bindableProxyFactor
}
}

/**
 * Determines the target partition for an outbound item using the producer properties
 * of the given output binding.
 *
 * @param outputBinding name of the output binding whose producer properties are consulted
 * @param message the outbound item; only {@link Message} instances are evaluated
 * @return the partition computed by a {@link PartitionHandler} for this binding, or
 * {@code null} when the item is not a {@link Message}, the binding has no producer
 * properties, or the producer is not partitioned
 */
private Integer determinePartitionForOutputBinding(String outputBinding, Object message) {
	// Check the item type first so that no evaluation context or handler is created
	// for items that cannot be partitioned anyway.
	if (!(message instanceof Message<?> outboundMessage)) {
		return null;
	}
	BindingProperties bindingProperties = this.serviceProperties.getBindings().get(outputBinding);
	ProducerProperties producerProperties = (bindingProperties != null) ? bindingProperties.getProducer() : null;
	if (producerProperties == null || !producerProperties.isPartitioned()) {
		return null;
	}
	StandardEvaluationContext evaluationContext =
			ExpressionUtils.createStandardEvaluationContext(this.applicationContext.getBeanFactory());
	PartitionHandler partitionHandler =
			new PartitionHandler(evaluationContext, producerProperties, this.applicationContext.getBeanFactory());
	return partitionHandler.determinePartition(outboundMessage);
}

private AbstractMessageHandler createFunctionHandler(FunctionInvocationWrapper function,
String inputChannelName, String outputChannelName) {
ConsumerProperties consumerProperties = StringUtils.hasText(inputChannelName)
Expand Down
Expand Up @@ -51,6 +51,8 @@ class PartitionAwareFunctionWrapper implements Function<Object, Object>, Supplie

private final Function<Object, Object> outputMessageEnricher;

private boolean messageEnricherEnabled = true;

PartitionAwareFunctionWrapper(Function<?, ?> function, ConfigurableApplicationContext context, ProducerProperties producerProperties) {
this.function = function;

Expand Down Expand Up @@ -84,7 +86,9 @@ private Message<?> toMessageWithPartitionHeader(Message message, PartitionHandle
@SuppressWarnings("unchecked")
@Override
public Object apply(Object input) {
this.setEnhancerIfNecessary();
if (this.messageEnricherEnabled) {
this.setEnhancerIfNecessary();
}
Object result = this.function.apply(input);
if (!((FunctionInvocationWrapper) this.function).isInputTypePublisher()) {
((FunctionInvocationWrapper) this.function).setEnhancer(null);
Expand All @@ -95,7 +99,9 @@ public Object apply(Object input) {
@Override
public Object get() {
if (this.function instanceof FunctionInvocationWrapper functionInvocationWrapper) {
this.setEnhancerIfNecessary();
if (this.messageEnricherEnabled) {
this.setEnhancerIfNecessary();
}
return functionInvocationWrapper.get();
}
throw new IllegalStateException("Call to get() is not allowed since this function is not a Supplier.");
Expand All @@ -106,4 +112,8 @@ private void setEnhancerIfNecessary() {
functionInvocationWrapper.setEnhancer(this.outputMessageEnricher);
}
}

/**
 * Enables or disables the output message enricher of this wrapper. The flag is
 * {@code true} by default; {@code apply} and {@code get} only call
 * {@code setEnhancerIfNecessary()} while it is {@code true}.
 * @param messageEnricherEnabled {@code false} to skip installing the enricher
 */
public void setMessageEnricherEnabled(boolean messageEnricherEnabled) {
this.messageEnricherEnabled = messageEnricherEnabled;
}
}

0 comments on commit 0f00684

Please sign in to comment.