diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectEvent.java new file mode 100644 index 000000000000..348609b1a7e1 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectEvent.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.messaging; + + +import org.springframework.context.ApplicationEvent; +import org.springframework.messaging.Message; +import org.springframework.util.Assert; + +/** + * Event raised when a new WebSocket client using a Simple Messaging Protocol + * (e.g. STOMP) as the WebSocket sub-protocol issues a connect request. + * + *

Note that this is not the same as the WebSocket session getting established + * but rather the client's first attempt to connect within the the sub-protocol, + * for example sending the STOMP CONNECT frame. + * + *

The provided {@link #getMessage() message} can be examined to check + * information about the connected user, The session id, and any headers + * sent by the client, for STOMP check the class + * {@link org.springframework.messaging.simp.stomp.StompHeaderAccessor}. + * For example: + * + *

+ * StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
+ * headers.getSessionId();
+ * headers.getSessionAttributes();
+ * headers.getPrincipal();
+ * 
+ * + * @author Rossen Stoyanchev + * @since 4.0.3 + */ +@SuppressWarnings("serial") +public class SessionConnectEvent extends ApplicationEvent { + + private final Message message; + + + /** + * Create a new SessionConnectEvent. + * + * @param source the component that published the event (never {@code null}) + * @param message the connect message + */ + public SessionConnectEvent(Object source, Message message) { + super(source); + Assert.notNull(message, "'message' must not be null"); + this.message = message; + } + + /** + * Return the connect message. + */ + public Message getMessage() { + return this.message; + } + + + @Override + public String toString() { + return "SessionConnectEvent: message=" + message; + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectedEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectedEvent.java new file mode 100644 index 000000000000..45e47dce81e5 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionConnectedEvent.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.messaging; + + +import org.springframework.context.ApplicationEvent; +import org.springframework.messaging.Message; +import org.springframework.util.Assert; + +/** + * A connected event represents the server response to a client's connect request. + * See {@link org.springframework.web.socket.messaging.SessionConnectEvent}. + * + * @author Rossen Stoyanchev + * @since 4.0.3 + */ +@SuppressWarnings("serial") +public class SessionConnectedEvent extends ApplicationEvent { + + private final Message message; + + + /** + * Create a new event. + * + * @param source the component that published the event (never {@code null}) + * @param message the connected message + */ + public SessionConnectedEvent(Object source, Message message) { + super(source); + Assert.notNull(message, "'message' must not be null"); + this.message = message; + } + + /** + * Return the connected message. + */ + public Message getMessage() { + return this.message; + } + + + @Override + public String toString() { + return "SessionConnectedEvent: message=" + message; + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java new file mode 100644 index 000000000000..4f2911155788 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SessionDisconnectEvent.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.messaging; + + +import org.springframework.context.ApplicationEvent; +import org.springframework.messaging.Message; +import org.springframework.util.Assert; +import org.springframework.web.socket.CloseStatus; + +/** + * Event raised when the session of a WebSocket client using a Simple Messaging + * Protocol (e.g. STOMP) as the WebSocket sub-protocol is closed. + * + *

Note that this event may be raised more than once for a single session and + * therefore event consumers should be idempotent and ignore a duplicate event.. + * + * @author Rossen Stoyanchev + * @since 4.0.3 + */ +@SuppressWarnings("serial") +public class SessionDisconnectEvent extends ApplicationEvent { + + private final String sessionId; + + private final CloseStatus status; + + /** + * Create a new event. + * + * @param source the component that published the event (never {@code null}) + * @param sessionId the disconnect message + * @param closeStatus + */ + public SessionDisconnectEvent(Object source, String sessionId, CloseStatus closeStatus) { + super(source); + Assert.notNull(sessionId, "'sessionId' must not be null"); + this.sessionId = sessionId; + this.status = closeStatus; + } + + /** + * Return the session id. + */ + public String getSessionId() { + return this.sessionId; + } + + /** + * Return the status with which the session was closed. + */ + public CloseStatus getCloseStatus() { + return this.status; + } + + @Override + public String toString() { + return "SessionDisconnectEvent: sessionId=" + this.sessionId; + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 19ca5ba901d7..2996d819ad19 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -28,6 +28,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.simp.SimpMessageType; @@ -57,7 +59,7 @@ * @author Andy Wilkinson * @since 4.0 */ -public class StompSubProtocolHandler implements SubProtocolHandler { +public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware { /** * The name of the header set on the CONNECTED frame indicating the name @@ -76,6 +78,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler { private UserSessionRegistry userSessionRegistry; + private ApplicationEventPublisher eventPublisher; + /** * Configure the maximum size allowed for an incoming STOMP message. @@ -120,6 +124,12 @@ public List getSupportedProtocols() { return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"); } + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { + this.eventPublisher = applicationEventPublisher; + } + + /** * Handle incoming WebSocket messages from clients. */ @@ -167,6 +177,11 @@ public void handleMessageFromClient(WebSocketSession session, headers.setUser(session.getPrincipal()); message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + + if (SimpMessageType.CONNECT.equals(headers.getMessageType()) && this.eventPublisher != null) { + this.eventPublisher.publishEvent(new SessionConnectEvent(this, message)); + } + outputChannel.send(message); } catch (Throwable ex) { @@ -231,6 +246,11 @@ else if (SimpMessageType.MESSAGE.equals(headers.getMessageType())) { try { message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build(); + + if (headers.getCommand() == StompCommand.CONNECTED && this.eventPublisher != null) { + this.eventPublisher.publishEvent(new SessionConnectedEvent(this, (Message) message)); + } + byte[] bytes = this.stompEncoder.encode((Message) message); TextMessage textMessage = new TextMessage(bytes); @@ -329,6 +349,11 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); headers.setSessionId(session.getId()); Message message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + + if (this.eventPublisher != null) { + this.eventPublisher.publishEvent(new SessionDisconnectEvent(this, session.getId(), closeStatus)); + } + outputChannel.send(message); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index c51d38e543b8..a60d63299f28 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -28,6 +28,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.SmartLifecycle; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; @@ -60,8 +62,8 @@ * @author Andy Wilkinson * @since 4.0 */ -public class SubProtocolWebSocketHandler - implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle { +public class SubProtocolWebSocketHandler implements WebSocketHandler, + SubProtocolCapable, MessageHandler, SmartLifecycle, ApplicationEventPublisherAware { private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); @@ -84,6 +86,8 @@ public class SubProtocolWebSocketHandler private volatile boolean running = false; + private ApplicationEventPublisher eventPublisher; + public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) { Assert.notNull(clientInboundChannel, "ClientInboundChannel must not be null"); @@ -114,11 +118,13 @@ public List getProtocolHandlers() { * Register a sub-protocol handler. */ public void addProtocolHandler(SubProtocolHandler handler) { + List protocols = handler.getSupportedProtocols(); if (CollectionUtils.isEmpty(protocols)) { logger.warn("No sub-protocols, ignoring handler " + handler); return; } + for (String protocol: protocols) { SubProtocolHandler replaced = this.protocolHandlers.put(protocol, handler); if ((replaced != null) && (replaced != handler) ) { @@ -126,6 +132,10 @@ public void addProtocolHandler(SubProtocolHandler handler) { + " to protocol '" + protocol + "', it is already mapped to handler " + replaced); } } + + if (handler instanceof ApplicationEventPublisherAware) { + ((ApplicationEventPublisherAware) handler).setApplicationEventPublisher(this.eventPublisher); + } } /** @@ -178,6 +188,10 @@ public int getSendBufferSizeLimit() { return sendBufferSizeLimit; } + @Override + public void setApplicationEventPublisher(ApplicationEventPublisher eventPublisher) { + this.eventPublisher = eventPublisher; + } @Override public boolean isAutoStartup() { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index 9f0a292541fc..89c10ce07344 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.messaging; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -26,6 +27,8 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; +import org.springframework.context.ApplicationEvent; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; @@ -33,17 +36,18 @@ import org.springframework.messaging.simp.TestPrincipal; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompDecoder; +import org.springframework.messaging.simp.stomp.StompEncoder; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.user.DefaultUserSessionRegistry; import org.springframework.messaging.simp.user.DestinationUserNameProvider; import org.springframework.messaging.simp.user.UserDestinationMessageHandler; import org.springframework.messaging.simp.user.UserSessionRegistry; import org.springframework.messaging.support.MessageBuilder; +import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.handler.TestWebSocketSession; import org.springframework.web.socket.sockjs.transport.SockJsSession; -import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -157,6 +161,33 @@ public void handleMessageToClientConnectAck() { assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0)); } + @Test + public void eventPublication() { + + TestPublisher publisher = new TestPublisher(); + + UserSessionRegistry registry = new DefaultUserSessionRegistry(); + this.protocolHandler.setUserSessionRegistry(registry); + this.protocolHandler.setApplicationEventPublisher(publisher); + this.protocolHandler.afterSessionStarted(this.session, this.channel); + + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + TextMessage textMessage = new TextMessage(new StompEncoder().encode( + MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build())); + this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel); + + headers = StompHeaderAccessor.create(StompCommand.CONNECTED); + Message connectedMessage = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build(); + this.protocolHandler.handleMessageToClient(this.session, connectedMessage); + + this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel); + + assertEquals(3, publisher.events.size()); + assertEquals(SessionConnectEvent.class, publisher.events.get(0).getClass()); + assertEquals(SessionConnectedEvent.class, publisher.events.get(1).getClass()); + assertEquals(SessionDisconnectEvent.class, publisher.events.get(2).getClass()); + } + @Test public void handleMessageToClientUserDestination() { @@ -225,4 +256,14 @@ public String getDestinationUserName() { } } + private static class TestPublisher implements ApplicationEventPublisher { + + private final List events = new ArrayList(); + + @Override + public void publishEvent(ApplicationEvent event) { + events.add(event); + } + } + }