Skip to content

Commit

Permalink
Handle STOMP messages to user destination in order
Browse files Browse the repository at this point in the history
Closes gh-31395
  • Loading branch information
rstoyanchev committed Oct 11, 2023
1 parent 9eb39e1 commit 3277b0d
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ written to WebSocket sessions. As the channel is backed by a `ThreadPoolExecutor
are processed in different threads, and the resulting sequence received by the client may
not match the exact order of publication.

If this is an issue, enable the `setPreservePublishOrder` flag, as the following example shows:
To enable ordered publishing, set the `setPreservePublishOrder` flag as follows:

[source,java,indent=0,subs="verbatim,quotes"]
----
Expand Down Expand Up @@ -47,5 +47,22 @@ When the flag is set, messages within the same client session are published to t
`clientOutboundChannel` one at a time, so that the order of publication is guaranteed.
Note that this incurs a small performance overhead, so you should enable it only if it is required.

The same also applies to messages from the client, which are sent to the `clientInboundChannel`,
from where they are handled according to their destination prefix. As the channel is backed by
a `ThreadPoolExecutor`, messages are processed in different threads, and the resulting sequence
of handling may not match the exact order in which they were received.

To enable ordered publishing, set the `setPreserveReceiveOrder` flag as follows:

[source,java,indent=0,subs="verbatim,quotes"]
----
@Configuration
@EnableWebSocketMessageBroker
public class MyConfig implements WebSocketMessageBrokerConfigurer {
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.setPreserveReceiveOrder(true);
}
}
----
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ protected void doSend(String destination, Message<?> message) {
if (simpAccessor.isMutable()) {
simpAccessor.setDestination(destination);
simpAccessor.setMessageTypeIfNotSet(SimpMessageType.MESSAGE);
simpAccessor.setImmutable();
// ImmutableMessageChannelInterceptor will make it immutable
sendInternal(message);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ else if (channel instanceof ExecutorSubscribableChannel execChannel) {
}
}

/**
* Whether the channel has been {@link #configureInterceptor configured}
* with an interceptor for sequential handling.
* @since 6.1
*/
public static boolean supportsOrderedMessages(MessageChannel channel) {
return (channel instanceof ExecutorSubscribableChannel ch &&
ch.getInterceptors().stream().anyMatch(CallbackTaskInterceptor.class::isInstance));
}

/**
* Obtain the task to release the next message, if found.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,17 @@ public UserDestinationResult resolveDestination(Message<?> message) {
}
String user = parseResult.getUser();
String sourceDest = parseResult.getSourceDestination();
Set<String> sessionIds = parseResult.getSessionIds();
Set<String> targetSet = new HashSet<>();
for (String sessionId : parseResult.getSessionIds()) {
for (String sessionId : sessionIds) {
String actualDest = parseResult.getActualDestination();
String targetDest = getTargetDestination(sourceDest, actualDest, sessionId, user);
if (targetDest != null) {
targetSet.add(targetDest);
}
}
String subscribeDest = parseResult.getSubscribeDestination();
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user);
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user, sessionIds);
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,13 +17,18 @@
package org.springframework.messaging.simp.user;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.logging.Log;

import org.springframework.context.SmartLifecycle;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.MessagingException;
Expand All @@ -33,6 +38,7 @@
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderInitializer;
Expand Down Expand Up @@ -61,7 +67,7 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec

private final UserDestinationResolver destinationResolver;

private final MessageSendingOperations<String> messagingTemplate;
private final SendHelper sendHelper;

@Nullable
private BroadcastHandler broadcastHandler;
Expand Down Expand Up @@ -91,7 +97,7 @@ public UserDestinationMessageHandler(

this.clientInboundChannel = clientInboundChannel;
this.brokerChannel = brokerChannel;
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
this.sendHelper = new SendHelper(clientInboundChannel, brokerChannel);
this.destinationResolver = destinationResolver;
}

Expand All @@ -112,7 +118,7 @@ public UserDestinationResolver getUserDestinationResolver() {
*/
public void setBroadcastDestination(@Nullable String destination) {
this.broadcastHandler = (StringUtils.hasText(destination) ?
new BroadcastHandler(this.messagingTemplate, destination) : null);
new BroadcastHandler(this.sendHelper.getMessagingTemplate(), destination) : null);
}

/**
Expand All @@ -128,7 +134,7 @@ public String getBroadcastDestination() {
* broker channel.
*/
public MessageSendingOperations<String> getBrokerMessagingTemplate() {
return this.messagingTemplate;
return this.sendHelper.getMessagingTemplate();
}

/**
Expand Down Expand Up @@ -193,6 +199,7 @@ public void handleMessage(Message<?> sourceMessage) throws MessagingException {

UserDestinationResult result = this.destinationResolver.resolveDestination(message);
if (result == null) {
this.sendHelper.checkDisconnect(message);
return;
}

Expand All @@ -215,9 +222,8 @@ public void handleMessage(Message<?> sourceMessage) throws MessagingException {
if (logger.isTraceEnabled()) {
logger.trace("Translated " + result.getSourceDestination() + " -> " + result.getTargetDestinations());
}
for (String target : result.getTargetDestinations()) {
this.messagingTemplate.send(target, message);
}

this.sendHelper.send(result, message);
}

private void initHeaders(SimpMessageHeaderAccessor headerAccessor) {
Expand All @@ -232,6 +238,63 @@ public String toString() {
}


private static class SendHelper {

private final MessageChannel brokerChannel;

private final MessageSendingOperations<String> messagingTemplate;

@Nullable
private final Map<String, MessageSendingOperations<String>> orderedMessagingTemplates;

SendHelper(MessageChannel clientInboundChannel, MessageChannel brokerChannel) {
this.brokerChannel = brokerChannel;
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
if (OrderedMessageChannelDecorator.supportsOrderedMessages(clientInboundChannel)) {
this.orderedMessagingTemplates = new ConcurrentHashMap<>();
OrderedMessageChannelDecorator.configureInterceptor(brokerChannel, true);
}
else {
this.orderedMessagingTemplates = null;
}
}

public MessageSendingOperations<String> getMessagingTemplate() {
return this.messagingTemplate;
}

public void send(UserDestinationResult destinationResult, Message<?> message) throws MessagingException {
Set<String> sessionIds = destinationResult.getSessionIds();
Iterator<String> itr = (sessionIds != null ? sessionIds.iterator() : null);

for (String target : destinationResult.getTargetDestinations()) {
String sessionId = (itr != null ? itr.next() : null);
getTemplateToUse(sessionId).send(target, message);
}
}

private MessageSendingOperations<String> getTemplateToUse(@Nullable String sessionId) {
if (this.orderedMessagingTemplates != null && sessionId != null) {
return this.orderedMessagingTemplates.computeIfAbsent(sessionId, id ->
new SimpMessagingTemplate(new OrderedMessageChannelDecorator(this.brokerChannel, logger)));
}
return this.messagingTemplate;
}

public void checkDisconnect(Message<?> message) {
if (this.orderedMessagingTemplates != null) {
MessageHeaders headers = message.getHeaders();
if (SimpMessageHeaderAccessor.getMessageType(headers) == SimpMessageType.DISCONNECT) {
String sessionId = SimpMessageHeaderAccessor.getSessionId(headers);
if (sessionId != null) {
this.orderedMessagingTemplates.remove(sessionId);
}
}
}
}
}


/**
* A handler that broadcasts locally unresolved messages to the broker and
* also handles similar broadcasts received from the broker.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

package org.springframework.messaging.simp.user;

import java.util.Collections;
import java.util.Set;

import org.springframework.lang.Nullable;
Expand All @@ -40,10 +41,23 @@ public class UserDestinationResult {
@Nullable
private final String user;

private final Set<String> sessionIds;


public UserDestinationResult(String sourceDestination, Set<String> targetDestinations,
String subscribeDestination, @Nullable String user) {

this(sourceDestination, targetDestinations, subscribeDestination, user, null);
}

/**
* Additional constructor with the session id for each targetDestination.
* @since 6.1
*/
public UserDestinationResult(
String sourceDestination, Set<String> targetDestinations,
String subscribeDestination, @Nullable String user, @Nullable Set<String> sessionIds) {

Assert.notNull(sourceDestination, "'sourceDestination' must not be null");
Assert.notNull(targetDestinations, "'targetDestinations' must not be null");
Assert.notNull(subscribeDestination, "'subscribeDestination' must not be null");
Expand All @@ -52,6 +66,7 @@ public UserDestinationResult(String sourceDestination, Set<String> targetDestina
this.targetDestinations = targetDestinations;
this.subscribeDestination = subscribeDestination;
this.user = user;
this.sessionIds = (sessionIds != null ? sessionIds : Collections.emptySet());
}


Expand Down Expand Up @@ -96,6 +111,13 @@ public String getUser() {
return this.user;
}

/**
* Return the session id for the targetDestination.
*/
@Nullable
public Set<String> getSessionIds() {
return this.sessionIds;
}

@Override
public String toString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ public void convertAndSendWithMutableSimpMessageHeaders() {
Message<byte[]> message = messages.get(0);

assertThat(message.getHeaders()).isSameAs(headers);
assertThat(accessor.isMutable()).isFalse();
}

@Test
Expand Down Expand Up @@ -190,7 +189,6 @@ public void doSendWithMutableHeaders() {
Message<byte[]> sentMessage = messages.get(0);

assertThat(sentMessage).isSameAs(message);
assertThat(accessor.isMutable()).isFalse();
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import jakarta.servlet.Filter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.AfterEach;
Expand All @@ -35,6 +36,7 @@
import org.springframework.context.Lifecycle;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.lang.Nullable;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
Expand Down Expand Up @@ -85,11 +87,18 @@ static Stream<Arguments> argumentsFactory() {
protected AnnotationConfigWebApplicationContext wac;


protected void setup(WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
protected void setup(WebSocketTestServer server, WebSocketClient client, TestInfo info) throws Exception {
setup(server, null, client, info);
}

protected void setup(
WebSocketTestServer server, @Nullable Filter filter, WebSocketClient client, TestInfo info)
throws Exception {

this.server = server;
this.webSocketClient = webSocketClient;
this.webSocketClient = client;

logger.debug("Setting up '" + testInfo.getTestMethod().get().getName() + "', client=" +
logger.debug("Setting up '" + info.getTestMethod().get().getName() + "', client=" +
this.webSocketClient.getClass().getSimpleName() + ", server=" +
this.server.getClass().getSimpleName());

Expand All @@ -102,7 +111,12 @@ protected void setup(WebSocketTestServer server, WebSocketClient webSocketClient
}

this.server.setup();
this.server.deployConfig(this.wac);
if (filter != null) {
this.server.deployConfig(this.wac, filter);
}
else {
this.server.deployConfig(this.wac);
}
this.server.start();

this.wac.setServletContext(this.server.getServletContext());
Expand Down

0 comments on commit 3277b0d

Please sign in to comment.