Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@
import org.springframework.messaging.simp.broker.SimpleBrokerMessageHandler;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
Expand Down Expand Up @@ -141,7 +143,7 @@ public WebSocketInboundChannelAdapter(IntegrationWebSocketContainer webSocketCon
}

/**
* Set the message converters to use. These converters are used to convert the message to send for appropriate
* Set the message converters to use. These converters are used to convert the message to send for the appropriate
* internal subProtocols type.
* @param messageConverters The message converters.
*/
Expand All @@ -160,7 +162,7 @@ public void setMergeWithDefaultConverters(boolean mergeWithDefaultConverters) {
}

/**
* Set the type for target message payload to convert the WebSocket message body to.
* Set the type for the target message payload to convert the WebSocket message body to.
* @param payloadType to convert inbound WebSocket message body
* @see CompositeMessageConverter
*/
Expand All @@ -174,9 +176,9 @@ public void setPayloadType(Class<?> payloadType) {
* bean for {@code non-MESSAGE} {@link org.springframework.web.socket.WebSocketMessage}s
* and to route messages with broker destinations.
* Since only single {@link AbstractBrokerMessageHandler} bean is allowed in the current
* application context, the algorithm to lookup the former by type, rather than applying
* application context, the algorithm is to look up the former by type, rather than applying
* the bean reference.
* This is used only on server side and is ignored from client side.
* This is used only on the server side and is ignored from the client side.
* @param useBroker the boolean flag.
*/
public void setUseBroker(boolean useBroker) {
Expand Down Expand Up @@ -234,13 +236,23 @@ public void afterSessionStarted(WebSocketSession session) {
SubProtocolHandler protocolHandler = this.subProtocolHandlerRegistry.findProtocolHandler(session);
protocolHandler.afterSessionStarted(session, this.subProtocolHandlerChannel);
if (!this.server && protocolHandler instanceof StompSubProtocolHandler) {
// The CONNECT frame is required by the STOMP specification.
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECT);
accessor.setSessionId(session.getId());
accessor.setLeaveMutable(true);
accessor.setAcceptVersion("1.1,1.2");

Message<?> connectMessage =
Message<byte[]> connectMessage =
MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());

// In the client mode, the client session has to register itself
// into the StompSubProtocolHandler cache
// for proper correlation of the messages from the server side.
StompEncoder stompEncoder = new StompEncoder();
byte[] connectMessageBytes = stompEncoder.encode(connectMessage);
protocolHandler.handleMessageFromClient(session, new BinaryMessage(connectMessageBytes),
this.subProtocolHandlerChannel);

protocolHandler.handleMessageToClient(session, connectMessage);
}
}
Expand Down Expand Up @@ -313,7 +325,11 @@ private void handleMessageAndSend(final Message<?> message) {
SimpMessageType messageType = headerAccessor.getMessageType();
if (isProcessingTypeOrCommand(headerAccessor, stompCommand, messageType)) {
if (SimpMessageType.CONNECT.equals(messageType)) {
produceConnectAckMessage(message, headerAccessor);
// Ignore the CONNECT frame in the client mode.
// Essentially, it has been just initiated from the {@link #afterSessionStarted}.
if (this.server) {
produceConnectAckMessage(message, headerAccessor);
}
}
else if (StompCommand.CONNECTED.equals(stompCommand)) {
this.eventPublisher.publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
Expand All @@ -338,10 +354,10 @@ else if (StompCommand.RECEIPT.equals(stompCommand)) {
}
}

private boolean isProcessingTypeOrCommand(SimpMessageHeaderAccessor headerAccessor, @Nullable StompCommand stompCommand,
@Nullable SimpMessageType messageType) {
private boolean isProcessingTypeOrCommand(SimpMessageHeaderAccessor headerAccessor,
@Nullable StompCommand stompCommand, @Nullable SimpMessageType messageType) {

return (messageType == null // NOSONAR pretty simple logic
return (messageType == null
|| SimpMessageType.MESSAGE.equals(messageType)
|| (SimpMessageType.CONNECT.equals(messageType) && !this.useBroker)
|| StompCommand.CONNECTED.equals(stompCommand)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -85,6 +84,7 @@
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.messaging.AbstractSubProtocolEvent;
import org.springframework.web.socket.messaging.SessionConnectEvent;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
Expand All @@ -95,6 +95,7 @@
import org.springframework.web.socket.sockjs.client.WebSocketTransport;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.InstanceOfAssertFactories.type;

/**
* @author Artem Bilan
Expand All @@ -103,7 +104,6 @@
*/
@SpringJUnitConfig(classes = StompIntegrationTests.ClientConfig.class)
@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD)
@Disabled("TODO until the lastest fix from SF mitigation")
public class StompIntegrationTests {

@Value("#{server.serverContext}")
Expand All @@ -126,35 +126,44 @@ public class StompIntegrationTests {

@Test
public void sendMessageToController() throws Exception {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
this.webSocketOutputChannel.send(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());

Message<?> receive = this.webSocketEvents.receive(20000);
assertThat(receive).isNotNull();
Object event = receive.getPayload();
assertThat(event).isInstanceOf(SessionConnectedEvent.class);
Message<?> connectedMessage = ((SessionConnectedEvent) event).getMessage();
headers = StompHeaderAccessor.wrap(connectedMessage);
assertThat(headers.getCommand()).isEqualTo(StompCommand.CONNECTED);
assertThat(receive)
.extracting(Message::getPayload)
// We've just registered our own connected client session from the WebSocketInboundChannelAdapter
.isInstanceOf(SessionConnectEvent.class);

headers = StompHeaderAccessor.create(StompCommand.SEND);
receive = this.webSocketEvents.receive(20000);
assertThat(receive)
.extracting(Message::getPayload)
.asInstanceOf(type(SessionConnectedEvent.class))
.extracting(SessionConnectedEvent::getMessage)
.extracting(connectedMessage -> StompHeaderAccessor.wrap(connectedMessage).getCommand())
.isEqualTo(StompCommand.CONNECTED);

StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setSubscriptionId("sub1");
headers.setDestination("/app/simple");
Message<String> message = MessageBuilder.withPayload("foo").setHeaders(headers).build();

this.webSocketOutputChannel.send(message);

SimpleController controller = this.serverContext.getBean(SimpleController.class);
assertThat(controller.latch.await(20, TimeUnit.SECONDS)).isTrue();
assertThat(controller.latch.await(10, TimeUnit.SECONDS)).isTrue();
assertThat(controller.stompCommand).isEqualTo(StompCommand.SEND.name());
}

@Test
public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception {
Message<?> receive = this.webSocketEvents.receive(20000);
assertThat(receive).isNotNull();
Object event = receive.getPayload();
assertThat(event).isInstanceOf(SessionConnectedEvent.class);
assertThat(receive)
.extracting(Message::getPayload)
// We've just registered our own connected client session from the WebSocketInboundChannelAdapter
.isInstanceOf(SessionConnectEvent.class);

receive = this.webSocketEvents.receive(20000);
assertThat(receive)
.extracting(Message::getPayload)
.isInstanceOf(SessionConnectedEvent.class);

StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
headers.setSubscriptionId("subs1");
Expand All @@ -167,13 +176,14 @@ public void sendMessageToControllerAndReceiveReplyViaTopic() throws Exception {
this.webSocketOutputChannel.send(message);

receive = this.webSocketEvents.receive(20000);
assertThat(receive).isNotNull();
event = receive.getPayload();
assertThat(event).isInstanceOf(ReceiptEvent.class);
Message<?> receiptMessage = ((ReceiptEvent) event).getMessage();
headers = StompHeaderAccessor.wrap(receiptMessage);
assertThat(headers.getCommand()).isEqualTo(StompCommand.RECEIPT);
assertThat(headers.getReceiptId()).isEqualTo("myReceipt");
assertThat(receive)
.extracting(Message::getPayload)
.asInstanceOf(type(ReceiptEvent.class))
.extracting(event -> StompHeaderAccessor.wrap(event.getMessage()))
.satisfies(headerAccessor -> {
assertThat(headerAccessor.getCommand()).isEqualTo(StompCommand.RECEIPT);
assertThat(headerAccessor.getReceiptId()).isEqualTo("myReceipt");
});

waitForSubscribe("/topic/increment");

Expand Down Expand Up @@ -494,7 +504,7 @@ public void configureMessageBroker(MessageBrokerRegistry configurer) {
public ApplicationListener<SessionSubscribeEvent> webSocketEventListener(
final AbstractSubscribableChannel clientOutboundChannel) {
// Cannot be lambda because Java can't infer generic type from lambdas,
// therefore we end up with ClassCastException for other event types
// therefore, we end up with ClassCastException for other event types
return new ApplicationListener<SessionSubscribeEvent>() {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.util.Map;

import org.apache.tomcat.websocket.Constants;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -68,7 +67,6 @@
*/
@SpringJUnitConfig(classes = WebSocketClientTests.ClientConfig.class)
@DirtiesContext
@Disabled("TODO until the lastest fix from SF mitigation")
public class WebSocketClientTests {

@Autowired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.Collections;
import java.util.Map;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -64,7 +63,6 @@
*/
@SpringJUnitConfig
@DirtiesContext
@Disabled("TODO until the lastest fix from SF mitigation")
public class WebSocketInboundChannelAdapterTests {

@Value("#{server.serverContext.getBean('subProtocolWebSocketHandler')}")
Expand Down Expand Up @@ -98,6 +96,7 @@ public void testWebSocketInboundChannelAdapter() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
headers.setLeaveMutable(true);
headers.setSessionId(sessionId);
headers.setSubscriptionId("sub1");
Message<byte[]> message =
MessageBuilder.createMessage(ByteBuffer.allocate(0).array(), headers.getMessageHeaders());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import java.util.Collections;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -56,7 +55,6 @@
*/
@SpringJUnitConfig
@DirtiesContext
@Disabled("TODO until the lastest fix from SF mitigation")
public class WebSocketOutboundMessageHandlerTests {

@Autowired
Expand All @@ -68,22 +66,32 @@ public class WebSocketOutboundMessageHandlerTests {

@Test
public void testWebSocketOutboundMessageHandler() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
this.messageHandler.handleMessage(MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build());

headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setMessageId("mess0");
headers.setSubscriptionId("sub0");
headers.setDestination("/foo");
headers.setDestination("/dest");
String payload = "Hello World";
Message<String> message = MessageBuilder.withPayload(payload).setHeaders(headers).build();

this.messageHandler.handleMessage(message);

Message<?> received = this.clientInboundChannel.receive(10000);
assertThat(received).isNotNull();

StompHeaderAccessor receivedHeaders = StompHeaderAccessor.wrap(received);
assertThat(receivedHeaders.getMessageId()).isEqualTo("mess0");
assertThat(receivedHeaders.getSubscriptionId()).isEqualTo("sub0");
assertThat(receivedHeaders.getDestination()).isEqualTo("/foo");
assertThat(received)
.extracting(StompHeaderAccessor::wrap)
.extracting(StompHeaderAccessor::getCommand)
.isEqualTo(StompCommand.CONNECT);

received = this.clientInboundChannel.receive(10000);
assertThat(received)
.extracting(StompHeaderAccessor::wrap)
.satisfies(headerAccessor -> {
assertThat(headerAccessor.getMessageId()).isEqualTo("mess0");
assertThat(headerAccessor.getSubscriptionId()).isEqualTo("sub0");
assertThat(headerAccessor.getDestination()).isEqualTo("/dest");
});

Object receivedPayload = received.getPayload();
assertThat(receivedPayload).isInstanceOf(byte[].class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.Collections;
import java.util.List;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

Expand Down Expand Up @@ -118,7 +117,6 @@ public class WebSocketServerTests {
private Lifecycle requestUpgradeStrategy;

@Test
@Disabled("TODO until the lastest fix from SF mitigation")
public void testWebSocketOutboundMessageHandler() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
headers.setSubscriptionId("subs1");
Expand Down