Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
* The {@code partitionKeyFunction} is used to determine to which partition the message
* has to be dispatched.
* By default, the {@link IntegrationMessageHeaderAccessor#CORRELATION_ID} message header is used
* for partition key.
* for a partition key.
* <p>
* The actual dispatching and threading logic is implemented in the {@link PartitionedDispatcher}.
* <p>
Expand Down Expand Up @@ -71,7 +71,7 @@ public PartitionedChannel(int partitionCount) {
}

/**
* Instantiate based on a provided number of partitions and function for partition key against
* Instantiate based on a provided number of partitions and function for a partition key against
* the message.
* @param partitionCount the number of partitions in this channel.
* @param partitionKeyFunction the function to resolve a partition key against the message
Expand Down Expand Up @@ -123,6 +123,16 @@ public void setLoadBalancingStrategy(@Nullable LoadBalancingStrategy loadBalanci
getDispatcher().setLoadBalancingStrategy(loadBalancingStrategy);
}

/**
* Provide a size of the queue in the partition executor's worker.
* Default to zero.
* @param workerQueueSize the size of the partition executor's worker queue.
* @since 6.4.10
*/
public void setWorkerQueueSize(int workerQueueSize) {
getDispatcher().setWorkerQueueSize(workerQueueSize);
}

@Override
protected PartitionedDispatcher getDispatcher() {
return (PartitionedDispatcher) this.dispatcher;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;
import java.util.function.Predicate;

import org.jspecify.annotations.Nullable;

import org.springframework.integration.util.CallerBlocksPolicy;
import org.springframework.integration.util.ErrorHandlingTaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandler;
Expand Down Expand Up @@ -78,6 +83,8 @@ public class PartitionedDispatcher extends AbstractDispatcher {

private final Lock lock = new ReentrantLock();

private int workerQueueSize;

/**
* Instantiate based on a provided number of partitions and function for a partition key against
* the message to dispatch.
Expand Down Expand Up @@ -153,6 +160,17 @@ public void setMessageHandlingTaskDecorator(MessageHandlingTaskDecorator message
this.messageHandlingTaskDecorator = messageHandlingTaskDecorator;
}

/**
* Provide a size of the queue in the partition executor's worker.
* Default to zero.
* @param workerQueueSize the size of the partition executor's worker queue.
* @since 6.4.10
*/
public void setWorkerQueueSize(int workerQueueSize) {
Assert.isTrue(workerQueueSize >= 0, "'workerQueueSize' must be greater than or equal to 0.");
this.workerQueueSize = workerQueueSize;
}

/**
* Shutdown this dispatcher on application close.
* The partition executors are shutdown and the internal state of this instance is cleared.
Expand Down Expand Up @@ -188,7 +206,16 @@ private void populatedPartitions() {
}

private UnicastingDispatcher newPartition() {
ExecutorService executor = Executors.newSingleThreadExecutor(this.threadFactory);
BlockingQueue<Runnable> workQueue =
this.workerQueueSize == 0
? new SynchronousQueue<>()
: new LinkedBlockingQueue<>(this.workerQueueSize);
ExecutorService executor =
new ThreadPoolExecutor(1, 1,
0L, TimeUnit.MILLISECONDS,
workQueue,
this.threadFactory,
new CallerBlocksPolicy(Long.MAX_VALUE));
this.executors.add(executor);

Executor effectiveExecutor = this.errorHandler != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public class PartitionedChannelSpec extends LoadBalancingChannelSpec<Partitioned

private @Nullable ThreadFactory threadFactory;

private int workerQueueSize;

protected PartitionedChannelSpec(int partitionCount) {
this.partitionCount = partitionCount;
}
Expand All @@ -53,6 +55,18 @@ public PartitionedChannelSpec threadFactory(ThreadFactory threadFactory) {
return this;
}

/**
* Provide a size of the queue in the partition executor's worker.
* Default to zero.
* @param workerQueueSize the size of the partition executor's worker queue.
* @return the spec.
* @since 6.4.10
*/
public PartitionedChannelSpec workerQueueSize(int workerQueueSize) {
this.workerQueueSize = workerQueueSize;
return this;
}

@Override
protected PartitionedChannel doGet() {
if (this.partitionKeyFunction != null) {
Expand All @@ -62,6 +76,7 @@ protected PartitionedChannel doGet() {
this.channel = new PartitionedChannel(this.partitionCount);
}
this.channel.setLoadBalancingStrategy(this.loadBalancingStrategy);
this.channel.setWorkerQueueSize(this.workerQueueSize);
if (this.failoverStrategy != null) {
this.channel.setFailoverStrategy(this.failoverStrategy);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
Expand All @@ -36,6 +41,7 @@
import org.springframework.integration.config.EnableIntegration;
import org.springframework.integration.dsl.IntegrationFlow;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.integration.test.util.TestUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
Expand All @@ -48,6 +54,7 @@
import org.springframework.util.MultiValueMap;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.InstanceOfAssertFactories.type;
import static org.mockito.Mockito.mock;

/**
Expand Down Expand Up @@ -128,6 +135,11 @@ public void afterMessageHandled(Message<?> message, MessageChannel ch, MessageHa
String partitionForLastMessage = partitionedMessages.keySet().iterator().next();
assertThat(partitionForLastMessage).isIn(allocatedPartitions);

List<?> partitionExecutors = TestUtils.getPropertyValue(partitionedChannel, "dispatcher.executors", List.class);
BlockingQueue<?> workQueue = ((ThreadPoolExecutor) partitionExecutors.get(0)).getQueue();

assertThat(workQueue).isInstanceOf(SynchronousQueue.class);

partitionedChannel.destroy();
}

Expand All @@ -138,6 +150,9 @@ public void afterMessageHandled(Message<?> message, MessageChannel ch, MessageHa
@Autowired
PollableChannel resultChannel;

@Autowired
PartitionedChannel testChannel;

@Test
void messagesArePartitionedByCorrelationId() {
this.inputChannel.send(new GenericMessage<>(IntStream.range(0, 5).toArray()));
Expand All @@ -153,6 +168,14 @@ void messagesArePartitionedByCorrelationId() {
Set<String> strings = new HashSet<>((Collection<? extends String>) receive.getPayload());
assertThat(strings).hasSize(1)
.allMatch(value -> value.startsWith("testChannel-partition-thread-"));

List<?> partitionExecutors = TestUtils.getPropertyValue(this.testChannel, "dispatcher.executors", List.class);
BlockingQueue<?> workQueue = ((ThreadPoolExecutor) partitionExecutors.get(0)).getQueue();

assertThat(workQueue)
.asInstanceOf(type(LinkedBlockingQueue.class))
.extracting(LinkedBlockingQueue::remainingCapacity)
.isEqualTo(1);
}

@Configuration
Expand All @@ -163,7 +186,7 @@ public static class TestConfiguration {
IntegrationFlow someFlow() {
return f -> f
.split()
.channel(c -> c.partitioned("testChannel", 10))
.channel(c -> c.partitioned("testChannel", 10).workerQueueSize(1))
.transform(p -> Thread.currentThread().getName())
.aggregate()
.channel(c -> c.queue("resultChannel"));
Expand Down