Skip to content

Commit

Permalink
Add session lifecycle ApplicationEvent's
Browse files Browse the repository at this point in the history
Issue: SPR-11578
  • Loading branch information
rstoyanchev committed Mar 25, 2014
1 parent 0745907 commit 13da705
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 4 deletions.
@@ -0,0 +1,78 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.socket.messaging;


import org.springframework.context.ApplicationEvent;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

/**
* Event raised when a new WebSocket client using a Simple Messaging Protocol
* (e.g. STOMP) as the WebSocket sub-protocol issues a connect request.
*
* <p>Note that this is not the same as the WebSocket session getting established
* but rather the client's first attempt to connect within the the sub-protocol,
* for example sending the STOMP CONNECT frame.
*
* <p>The provided {@link #getMessage() message} can be examined to check
* information about the connected user, The session id, and any headers
* sent by the client, for STOMP check the class
* {@link org.springframework.messaging.simp.stomp.StompHeaderAccessor}.
* For example:
*
* <pre class="code">
* StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
* headers.getSessionId();
* headers.getSessionAttributes();
* headers.getPrincipal();
* </pre>
*
* @author Rossen Stoyanchev
* @since 4.0.3
*/
@SuppressWarnings("serial")
public class SessionConnectEvent extends ApplicationEvent {

private final Message<byte[]> message;


/**
* Create a new SessionConnectEvent.
*
* @param source the component that published the event (never {@code null})
* @param message the connect message
*/
public SessionConnectEvent(Object source, Message<byte[]> message) {
super(source);
Assert.notNull(message, "'message' must not be null");
this.message = message;
}

/**
* Return the connect message.
*/
public Message<byte[]> getMessage() {
return this.message;
}


@Override
public String toString() {
return "SessionConnectEvent: message=" + message;
}
}
@@ -0,0 +1,61 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.socket.messaging;


import org.springframework.context.ApplicationEvent;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;

/**
* A connected event represents the server response to a client's connect request.
* See {@link org.springframework.web.socket.messaging.SessionConnectEvent}.
*
* @author Rossen Stoyanchev
* @since 4.0.3
*/
@SuppressWarnings("serial")
public class SessionConnectedEvent extends ApplicationEvent {

private final Message<byte[]> message;


/**
* Create a new event.
*
* @param source the component that published the event (never {@code null})
* @param message the connected message
*/
public SessionConnectedEvent(Object source, Message<byte[]> message) {
super(source);
Assert.notNull(message, "'message' must not be null");
this.message = message;
}

/**
* Return the connected message.
*/
public Message<byte[]> getMessage() {
return this.message;
}


@Override
public String toString() {
return "SessionConnectedEvent: message=" + message;
}
}
@@ -0,0 +1,74 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.socket.messaging;


import org.springframework.context.ApplicationEvent;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;

/**
* Event raised when the session of a WebSocket client using a Simple Messaging
* Protocol (e.g. STOMP) as the WebSocket sub-protocol is closed.
*
* <p>Note that this event may be raised more than once for a single session and
* therefore event consumers should be idempotent and ignore a duplicate event..
*
* @author Rossen Stoyanchev
* @since 4.0.3
*/
@SuppressWarnings("serial")
public class SessionDisconnectEvent extends ApplicationEvent {

private final String sessionId;

private final CloseStatus status;

/**
* Create a new event.
*
* @param source the component that published the event (never {@code null})
* @param sessionId the disconnect message
* @param closeStatus
*/
public SessionDisconnectEvent(Object source, String sessionId, CloseStatus closeStatus) {
super(source);
Assert.notNull(sessionId, "'sessionId' must not be null");
this.sessionId = sessionId;
this.status = closeStatus;
}

/**
* Return the session id.
*/
public String getSessionId() {
return this.sessionId;
}

/**
* Return the status with which the session was closed.
*/
public CloseStatus getCloseStatus() {
return this.status;
}

@Override
public String toString() {
return "SessionDisconnectEvent: sessionId=" + this.sessionId;
}
}
Expand Up @@ -28,6 +28,8 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpMessageType;
Expand Down Expand Up @@ -57,7 +59,7 @@
* @author Andy Wilkinson
* @since 4.0
*/
public class StompSubProtocolHandler implements SubProtocolHandler {
public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {

/**
* The name of the header set on the CONNECTED frame indicating the name
Expand All @@ -76,6 +78,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler {

private UserSessionRegistry userSessionRegistry;

private ApplicationEventPublisher eventPublisher;


/**
* Configure the maximum size allowed for an incoming STOMP message.
Expand Down Expand Up @@ -120,6 +124,12 @@ public List<String> getSupportedProtocols() {
return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
}

@Override
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
this.eventPublisher = applicationEventPublisher;
}


/**
* Handle incoming WebSocket messages from clients.
*/
Expand Down Expand Up @@ -167,6 +177,11 @@ public void handleMessageFromClient(WebSocketSession session,
headers.setUser(session.getPrincipal());

message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();

if (SimpMessageType.CONNECT.equals(headers.getMessageType()) && this.eventPublisher != null) {
this.eventPublisher.publishEvent(new SessionConnectEvent(this, message));
}

outputChannel.send(message);
}
catch (Throwable ex) {
Expand Down Expand Up @@ -231,6 +246,11 @@ else if (SimpMessageType.MESSAGE.equals(headers.getMessageType())) {

try {
message = MessageBuilder.withPayload(message.getPayload()).setHeaders(headers).build();

if (headers.getCommand() == StompCommand.CONNECTED && this.eventPublisher != null) {
this.eventPublisher.publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message));
}

byte[] bytes = this.stompEncoder.encode((Message<byte[]>) message);
TextMessage textMessage = new TextMessage(bytes);

Expand Down Expand Up @@ -329,6 +349,11 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.setSessionId(session.getId());
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();

if (this.eventPublisher != null) {
this.eventPublisher.publishEvent(new SessionDisconnectEvent(this, session.getId(), closeStatus));
}

outputChannel.send(message);
}

Expand Down
Expand Up @@ -28,6 +28,8 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
Expand Down Expand Up @@ -60,8 +62,8 @@
* @author Andy Wilkinson
* @since 4.0
*/
public class SubProtocolWebSocketHandler
implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle {
public class SubProtocolWebSocketHandler implements WebSocketHandler,
SubProtocolCapable, MessageHandler, SmartLifecycle, ApplicationEventPublisherAware {

private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);

Expand All @@ -84,6 +86,8 @@ public class SubProtocolWebSocketHandler

private volatile boolean running = false;

private ApplicationEventPublisher eventPublisher;


public SubProtocolWebSocketHandler(MessageChannel clientInboundChannel, SubscribableChannel clientOutboundChannel) {
Assert.notNull(clientInboundChannel, "ClientInboundChannel must not be null");
Expand Down Expand Up @@ -114,18 +118,24 @@ public List<SubProtocolHandler> getProtocolHandlers() {
* Register a sub-protocol handler.
*/
public void addProtocolHandler(SubProtocolHandler handler) {

List<String> protocols = handler.getSupportedProtocols();
if (CollectionUtils.isEmpty(protocols)) {
logger.warn("No sub-protocols, ignoring handler " + handler);
return;
}

for (String protocol: protocols) {
SubProtocolHandler replaced = this.protocolHandlers.put(protocol, handler);
if ((replaced != null) && (replaced != handler) ) {
throw new IllegalStateException("Failed to map handler " + handler
+ " to protocol '" + protocol + "', it is already mapped to handler " + replaced);
}
}

if (handler instanceof ApplicationEventPublisherAware) {
((ApplicationEventPublisherAware) handler).setApplicationEventPublisher(this.eventPublisher);
}
}

/**
Expand Down Expand Up @@ -178,6 +188,10 @@ public int getSendBufferSizeLimit() {
return sendBufferSizeLimit;
}

@Override
public void setApplicationEventPublisher(ApplicationEventPublisher eventPublisher) {
this.eventPublisher = eventPublisher;
}

@Override
public boolean isAutoStartup() {
Expand Down

0 comments on commit 13da705

Please sign in to comment.