getMessage() {
+ return this.message;
+ }
+
+
+ @Override
+ public String toString() {
+ return "SessionConnectedEvent: message=" + message;
+ }
+}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java
new file mode 100644
index 000000000000..4f2911155788
--- /dev/null
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2002-2014 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.socket.messaging;
+
+
+import org.springframework.context.ApplicationEvent;
+import org.springframework.messaging.Message;
+import org.springframework.util.Assert;
+import org.springframework.web.socket.CloseStatus;
+
+/**
+ * Event raised when the session of a WebSocket client using a Simple Messaging
+ * Protocol (e.g. STOMP) as the WebSocket sub-protocol is closed.
+ *
+ * Note that this event may be raised more than once for a single session and
+ * therefore event consumers should be idempotent and ignore a duplicate event..
+ *
+ * @author Rossen Stoyanchev
+ * @since 4.0.3
+ */
+@SuppressWarnings("serial")
+public class SessionDisconnectEvent extends ApplicationEvent {
+
+ private final String sessionId;
+
+ private final CloseStatus status;
+
+ /**
+ * Create a new event.
+ *
+ * @param source the component that published the event (never {@code null})
+ * @param sessionId the disconnect message
+ * @param closeStatus
+ */
+ public SessionDisconnectEvent(Object source, String sessionId, CloseStatus closeStatus) {
+ super(source);
+ Assert.notNull(sessionId, "'sessionId' must not be null");
+ this.sessionId = sessionId;
+ this.status = closeStatus;
+ }
+
+ /**
+ * Return the session id.
+ */
+ public String getSessionId() {
+ return this.sessionId;
+ }
+
+ /**
+ * Return the status with which the session was closed.
+ */
+ public CloseStatus getCloseStatus() {
+ return this.status;
+ }
+
+ @Override
+ public String toString() {
+ return "SessionDisconnectEvent: sessionId=" + this.sessionId;
+ }
+}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
index 19ca5ba901d7..2996d819ad19 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
@@ -28,6 +28,8 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.springframework.context.ApplicationEventPublisher;
+import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageType;
@@ -57,7 +59,7 @@
* @author Andy Wilkinson
* @since 4.0
*/
-public class StompSubProtocolHandler implements SubProtocolHandler {
+public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {
/**
* The name of the header set on the CONNECTED frame indicating the name
@@ -76,6 +78,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler {
private UserSessionRegistry userSessionRegistry;
+ private ApplicationEventPublisher eventPublisher;
+
/**
* Configure the maximum size allowed for an incoming STOMP message.
@@ -120,6 +124,12 @@ public List getSupportedProtocols() {
return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
}
+ @Override
+ public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
+ this.eventPublisher = applicationEventPublisher;
+ }
+
+
/**
* Handle incoming WebSocket messages from clients.
*/
@@ -167,6 +177,11 @@ public void handleMessageFromClient(WebSocketSession session,
headers.setUser(session.getPrincipal());
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
+
+ if (SimpMessageType.CONNECT.equals(headers.getMessageType()) && this.eventPublisher != null) {
+ this.eventPublisher.publishEvent(new SessionConnectEvent(this, message));
+ }
+
outputChannel.send(message);
}
catch (Throwable ex) {
@@ -231,6 +246,11 @@ else if (SimpMessageType.MESSAGE.equals(headers.getMessageType())) {
try {
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();
+
+ if (headers.getCommand() == StompCommand.CONNECTED && this.eventPublisher != null) {
+ this.eventPublisher.publishEvent(new SessionConnectedEvent(this, (Message) message));
+ }
+
byte[] bytes = this.stompEncoder.encode((Message) message);
TextMessage textMessage = new TextMessage(bytes);
@@ -329,6 +349,11 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(session.getId());
Message> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
+
+ if (this.eventPublisher != null) {
+ this.eventPublisher.publishEvent(new SessionDisconnectEvent(this, session.getId(), closeStatus));
+ }
+
outputChannel.send(message);
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
index c51d38e543b8..a60d63299f28 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
@@ -28,6 +28,8 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.springframework.context.ApplicationEventPublisher;
+import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
@@ -60,8 +62,8 @@
* @author Andy Wilkinson
* @since 4.0
*/
-public class SubProtocolWebSocketHandler
- implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle {
+public class SubProtocolWebSocketHandler implements WebSocketHandler,
+ SubProtocolCapable, MessageHandler, SmartLifecycle, ApplicationEventPublisherAware {
private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
@@ -84,6 +86,8 @@ public class SubProtocolWebSocketHandler
private volatile boolean running = false;
+ private ApplicationEventPublisher eventPublisher;
+
public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) {
Assert.notNull(clientInboundChannel, "ClientInboundChannel must not be null");
@@ -114,11 +118,13 @@ public List getProtocolHandlers() {
* Register a sub-protocol handler.
*/
public void addProtocolHandler(SubProtocolHandler handler) {
+
List protocols = handler.getSupportedProtocols();
if (CollectionUtils.isEmpty(protocols)) {
logger.warn("No sub-protocols, ignoring handler " + handler);
return;
}
+
for (String protocol: protocols) {
SubProtocolHandler replaced = this.protocolHandlers.put(protocol, handler);
if ((replaced != null) && (replaced != handler) ) {
@@ -126,6 +132,10 @@ public void addProtocolHandler(SubProtocolHandler handler) {
+ " to protocol '" + protocol + "', it is already mapped to handler " + replaced);
}
}
+
+ if (handler instanceof ApplicationEventPublisherAware) {
+ ((ApplicationEventPublisherAware) handler).setApplicationEventPublisher(this.eventPublisher);
+ }
}
/**
@@ -178,6 +188,10 @@ public int getSendBufferSizeLimit() {
return sendBufferSizeLimit;
}
+ @Override
+ public void setApplicationEventPublisher(ApplicationEventPublisher eventPublisher) {
+ this.eventPublisher = eventPublisher;
+ }
@Override
public boolean isAutoStartup() {
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java
index 9f0a292541fc..89c10ce07344 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java
@@ -17,6 +17,7 @@
package org.springframework.web.socket.messaging;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
@@ -26,6 +27,8 @@
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
+import org.springframework.context.ApplicationEvent;
+import org.springframework.context.ApplicationEventPublisher;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
@@ -33,17 +36,18 @@
import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompDecoder;
+import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserDestinationMessageHandler;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.MessageBuilder;
+import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.handler.TestWebSocketSession;
import org.springframework.web.socket.sockjs.transport.SockJsSession;
-import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
@@ -157,6 +161,33 @@ public void handleMessageToClientConnectAck() {
assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0));
}
+ @Test
+ public void eventPublication() {
+
+ TestPublisher publisher = new TestPublisher();
+
+ UserSessionRegistry registry = new DefaultUserSessionRegistry();
+ this.protocolHandler.setUserSessionRegistry(registry);
+ this.protocolHandler.setApplicationEventPublisher(publisher);
+ this.protocolHandler.afterSessionStarted(this.session, this.channel);
+
+ StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
+ TextMessage textMessage = new TextMessage(new StompEncoder().encode(
+ MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build()));
+ this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);
+
+ headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
+ Message connectedMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
+ this.protocolHandler.handleMessageToClient(this.session, connectedMessage);
+
+ this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel);
+
+ assertEquals(3, publisher.events.size());
+ assertEquals(SessionConnectEvent.class, publisher.events.get(0).getClass());
+ assertEquals(SessionConnectedEvent.class, publisher.events.get(1).getClass());
+ assertEquals(SessionDisconnectEvent.class, publisher.events.get(2).getClass());
+ }
+
@Test
public void handleMessageToClientUserDestination() {
@@ -225,4 +256,14 @@ public String getDestinationUserName() {
}
}
+ private static class TestPublisher implements ApplicationEventPublisher {
+
+ private final List events = new ArrayList();
+
+ @Override
+ public void publishEvent(ApplicationEvent event) {
+ events.add(event);
+ }
+ }
+
}