Skip to content

Commit

Permalink
Enable passing WebSocket handler instances.
Browse files Browse the repository at this point in the history
  • Loading branch information
jakaarl committed Aug 25, 2016
1 parent 12774f1 commit 11a78b5
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 54 deletions.
35 changes: 25 additions & 10 deletions src/main/java/spark/Service.java
Expand Up @@ -26,6 +26,9 @@

import spark.embeddedserver.EmbeddedServer;
import spark.embeddedserver.EmbeddedServers;
import spark.embeddedserver.jetty.websocket.WebSocketHandlerClassWrapper;
import spark.embeddedserver.jetty.websocket.WebSocketHandlerInstanceWrapper;
import spark.embeddedserver.jetty.websocket.WebSocketHandlerWrapper;
import spark.route.Routes;
import spark.route.ServletRoutes;
import spark.ssl.SslStores;
Expand Down Expand Up @@ -60,7 +63,7 @@ public final class Service extends Routable {
protected String staticFileFolder = null;
protected String externalStaticFileFolder = null;

protected Map<String, Class<?>> webSocketHandlers = null;
protected Map<String, WebSocketHandlerWrapper> webSocketHandlers = null;

protected int maxThreads = -1;
protected int minThreads = -1;
Expand Down Expand Up @@ -246,30 +249,42 @@ public synchronized Service externalStaticFileLocation(String externalFolder) {
}

/**
* Maps the given path to the given WebSocket handler.
* Maps the given path to the given WebSocket handler class.
* <p>
* This is currently only available in the embedded server mode.
*
* @param path the WebSocket path.
* @param handler the handler class that will manage the WebSocket connection to the given path.
*/
public synchronized void webSocket(String path, Class<?> handler) {
requireNonNull(path, "WebSocket path cannot be null");
requireNonNull(handler, "WebSocket handler class cannot be null");

public void webSocket(String path, Class<?> handlerClass) {
addWebSocketHandler(path, new WebSocketHandlerClassWrapper(handlerClass));
}

/**
* Maps the given path to the given WebSocket handler instance.
* <p>
* This is currently only available in the embedded server mode.
*
* @param path the WebSocket path.
* @param handler the handler instance that will manage the WebSocket connection to the given path.
*/
public void webSocket(String path, Object handler) {
addWebSocketHandler(path, new WebSocketHandlerInstanceWrapper(handler));
}

private synchronized void addWebSocketHandler(String path, WebSocketHandlerWrapper handlerWrapper) {
if (initialized) {
throwBeforeRouteMappingException();
}

if (isRunningFromServlet()) {
throw new IllegalStateException("WebSockets are only supported in the embedded server");
}

}
requireNonNull(path, "WebSocket path cannot be null");
if (webSocketHandlers == null) {
webSocketHandlers = new HashMap<>();
}

webSocketHandlers.put(path, handler);
webSocketHandlers.put(path, handlerWrapper);
}

/**
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/spark/Spark.java
Expand Up @@ -1053,6 +1053,10 @@ public static void stop() {
public static void webSocket(String path, Class<?> handler) {
getInstance().webSocket(path, handler);
}

public static void webSocket(String path, Object handler) {
getInstance().webSocket(path, handler);
}

/**
* Sets the max idle timeout in milliseconds for WebSocket connections.
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/spark/embeddedserver/EmbeddedServer.java
Expand Up @@ -20,6 +20,7 @@
import java.util.Optional;
import java.util.concurrent.CountDownLatch;

import spark.embeddedserver.jetty.websocket.WebSocketHandlerWrapper;
import spark.ssl.SslStores;

/**
Expand Down Expand Up @@ -53,7 +54,7 @@ void ignite(String host,
* @param webSocketHandlers - web socket handlers.
* @param webSocketIdleTimeoutMillis - Optional WebSocket idle timeout (ms).
*/
default void configureWebSockets(Map<String, Class<?>> webSocketHandlers,
default void configureWebSockets(Map<String, WebSocketHandlerWrapper> webSocketHandlers,
Optional<Integer> webSocketIdleTimeoutMillis) {

NotSupportedException.raise(getClass().getSimpleName(), "Web Sockets");
Expand Down
Expand Up @@ -35,6 +35,7 @@

import spark.ssl.SslStores;
import spark.embeddedserver.EmbeddedServer;
import spark.embeddedserver.jetty.websocket.WebSocketHandlerWrapper;
import spark.embeddedserver.jetty.websocket.WebSocketServletContextHandlerFactory;

/**
Expand All @@ -52,15 +53,15 @@ public class EmbeddedJettyServer implements EmbeddedServer {

private final Logger logger = LoggerFactory.getLogger(this.getClass());

private Map<String, Class<?>> webSocketHandlers;
private Map<String, WebSocketHandlerWrapper> webSocketHandlers;
private Optional<Integer> webSocketIdleTimeoutMillis;

public EmbeddedJettyServer(Handler handler) {
this.handler = handler;
}

@Override
public void configureWebSockets(Map<String, Class<?>> webSocketHandlers,
public void configureWebSockets(Map<String, WebSocketHandlerWrapper> webSocketHandlers,
Optional<Integer> webSocketIdleTimeoutMillis) {

this.webSocketHandlers = webSocketHandlers;
Expand Down
Expand Up @@ -15,8 +15,6 @@
*/
package spark.embeddedserver.jetty.websocket;

import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
Expand All @@ -32,35 +30,14 @@
public class WebSocketCreatorFactory {

/**
* Creates a {@link WebSocketCreator} that uses the given handler class for
* Creates a {@link WebSocketCreator} that uses the given handler class/instance for
* the WebSocket connections.
*
* @param handlerClass The handler to use to manage WebSocket connections.
* @param handlerWrapper The wrapped handler to use to manage WebSocket connections.
* @return The WebSocketCreator.
*/
public static WebSocketCreator create(Class<?> handlerClass) {
validate(handlerClass);
try {
Object handler = handlerClass.newInstance();
return new SparkWebSocketCreator(handler);
} catch (InstantiationException | IllegalAccessException ex) {
throw new RuntimeException("Could not instantiate websocket handler", ex);
}
}

/**
* Validates that the handler can actually handle the WebSocket connection.
*
* @param handlerClass The handler class to validate.
* @throws IllegalArgumentException if the class is not a valid handler class.
*/
private static void validate(Class<?> handlerClass) {
boolean valid = WebSocketListener.class.isAssignableFrom(handlerClass)
|| handlerClass.isAnnotationPresent(WebSocket.class);
if (!valid) {
throw new IllegalArgumentException(
"WebSocket handler must implement 'WebSocketListener' or be annotated as '@WebSocket'");
}
public static WebSocketCreator create(WebSocketHandlerWrapper handlerWrapper) {
return new SparkWebSocketCreator(handlerWrapper.getHandler());
}

// Package protected to be visible to the unit tests
Expand Down
@@ -0,0 +1,23 @@
package spark.embeddedserver.jetty.websocket;

import static java.util.Objects.requireNonNull;

public class WebSocketHandlerClassWrapper implements WebSocketHandlerWrapper {

private final Class<?> handlerClass;

public WebSocketHandlerClassWrapper(Class<?> handlerClass) {
requireNonNull(handlerClass, "WebSocket handler class cannot be null");
WebSocketHandlerWrapper.validateHandlerClass(handlerClass);
this.handlerClass = handlerClass;
}
@Override
public Object getHandler() {
try {
return handlerClass.newInstance();
} catch (InstantiationException | IllegalAccessException ex) {
throw new RuntimeException("Could not instantiate websocket handler", ex);
}
}

}
@@ -0,0 +1,20 @@
package spark.embeddedserver.jetty.websocket;

import static java.util.Objects.requireNonNull;

public class WebSocketHandlerInstanceWrapper implements WebSocketHandlerWrapper {

private final Object handler;

public WebSocketHandlerInstanceWrapper(Object handler) {
requireNonNull(handler, "WebSocket handler cannot be null");
WebSocketHandlerWrapper.validateHandlerClass(handler.getClass());
this.handler = handler;
}

@Override
public Object getHandler() {
return handler;
}

}
@@ -0,0 +1,27 @@
package spark.embeddedserver.jetty.websocket;

import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;

/**
* A wrapper for web socket handler classes/instances.
*/
public interface WebSocketHandlerWrapper {

/**
* Gets the actual handler - if necessary, instantiating an object.
*
* @return The handler instance.
*/
Object getHandler();

static void validateHandlerClass(Class<?> handlerClass) {
boolean valid = WebSocketListener.class.isAssignableFrom(handlerClass)
|| handlerClass.isAnnotationPresent(WebSocket.class);
if (!valid) {
throw new IllegalArgumentException(
"WebSocket handler must implement 'WebSocketListener' or be annotated as '@WebSocket'");
}
}

}
Expand Up @@ -40,7 +40,7 @@ public class WebSocketServletContextHandlerFactory {
* @param webSocketIdleTimeoutMillis webSocketIdleTimeoutMillis
* @return a new websocket servlet context handler or 'null' if creation failed.
*/
public static ServletContextHandler create(Map<String, Class<?>> webSocketHandlers,
public static ServletContextHandler create(Map<String, WebSocketHandlerWrapper> webSocketHandlers,
Optional<Integer> webSocketIdleTimeoutMillis) {
ServletContextHandler webSocketServletContextHandler = null;
if (webSocketHandlers != null) {
Expand Down
21 changes: 20 additions & 1 deletion src/test/java/spark/ServiceTest.java
Expand Up @@ -2,6 +2,7 @@

import javax.servlet.http.HttpServletResponse;

import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -191,6 +192,24 @@ public void testWebSocket_whenInitializedTrue_thenThrowIllegalStateException() {
thrown.expectMessage("This must be done before route mapping has begun");

Whitebox.setInternalState(service, "initialized", true);
service.webSocket("/", Object.class);
service.webSocket("/", DummyWebSocketListener.class);
}

@Test
public void testWebSocket_whenPathNull_thenThrowNullPointerException() {
thrown.expect(NullPointerException.class);
thrown.expectMessage("WebSocket path cannot be null");
service.webSocket(null, new DummyWebSocketListener());
}

@Test
public void testWebSocket_whenHandlerNull_thenThrowNullPointerException() {
thrown.expect(NullPointerException.class);
thrown.expectMessage("WebSocket handler class cannot be null");
service.webSocket("/", null);
}

@WebSocket
protected static class DummyWebSocketListener {
}
}
Expand Up @@ -15,19 +15,21 @@ public class WebSocketCreatorFactoryTest {

@Test
public void testCreateWebSocketHandler() {
WebSocketCreator annotated = WebSocketCreatorFactory.create(AnnotatedHandler.class);
WebSocketCreator annotated =
WebSocketCreatorFactory.create(new WebSocketHandlerClassWrapper(AnnotatedHandler.class));
assertTrue(annotated instanceof SparkWebSocketCreator);
assertTrue(SparkWebSocketCreator.class.cast(annotated).getHandler() instanceof AnnotatedHandler);

WebSocketCreator listener = WebSocketCreatorFactory.create(ListenerHandler.class);
WebSocketCreator listener =
WebSocketCreatorFactory.create(new WebSocketHandlerClassWrapper(ListenerHandler.class));
assertTrue(listener instanceof SparkWebSocketCreator);
assertTrue(SparkWebSocketCreator.class.cast(listener).getHandler() instanceof ListenerHandler);
}

@Test
public void testCannotCreateInvalidHandlers() {
try {
WebSocketCreatorFactory.create(InvalidHandler.class);
WebSocketCreatorFactory.create(new WebSocketHandlerClassWrapper(InvalidHandler.class));
fail("Handler creation should have thrown an IllegalArgumentException");
} catch (IllegalArgumentException ex) {
assertEquals(
Expand All @@ -38,9 +40,8 @@ public void testCannotCreateInvalidHandlers() {

@Test
public void testCreate_whenInstantiationException() throws Exception {

try {
WebSocketCreatorFactory.create(FailingHandler.class);
WebSocketCreatorFactory.create(new WebSocketHandlerClassWrapper(FailingHandler.class));
fail("Handler creation should have thrown a RunTimeException");
} catch(RuntimeException ex) {
assertEquals("Could not instantiate websocket handler", ex.getMessage());
Expand Down
Expand Up @@ -36,9 +36,9 @@ public void testCreate_whenWebSocketHandlersIsNull_thenReturnNull() throws Excep
@Test
public void testCreate_whenNoIdleTimeoutIsPresent() throws Exception {

Map<String, Class<?>> webSocketHandlers = new HashMap<>();
Map<String, WebSocketHandlerWrapper> webSocketHandlers = new HashMap<>();

webSocketHandlers.put(webSocketPath, WebSocketTestHandler.class);
webSocketHandlers.put(webSocketPath, new WebSocketHandlerClassWrapper(WebSocketTestHandler.class));

servletContextHandler = WebSocketServletContextHandlerFactory.create(webSocketHandlers, Optional.empty());

Expand All @@ -64,9 +64,9 @@ public void testCreate_whenTimeoutIsPresent() throws Exception {

final Integer timeout = Integer.valueOf(1000);

Map<String, Class<?>> webSocketHandlers = new HashMap<>();
Map<String, WebSocketHandlerWrapper> webSocketHandlers = new HashMap<>();

webSocketHandlers.put(webSocketPath, WebSocketTestHandler.class);
webSocketHandlers.put(webSocketPath, new WebSocketHandlerClassWrapper(WebSocketTestHandler.class));

servletContextHandler = WebSocketServletContextHandlerFactory.create(webSocketHandlers, Optional.of(timeout));

Expand Down Expand Up @@ -94,11 +94,11 @@ public void testCreate_whenTimeoutIsPresent() throws Exception {
@PrepareForTest(WebSocketServletContextHandlerFactory.class)
public void testCreate_whenWebSocketContextHandlerCreationFails_thenThrowException() throws Exception {

Map<String, Class<?>> webSocketHandlers = new HashMap<>();

PowerMockito.whenNew(ServletContextHandler.class).withAnyArguments().thenThrow(new Exception(""));

webSocketHandlers.put(webSocketPath, WebSocketTestHandler.class);
Map<String, WebSocketHandlerWrapper> webSocketHandlers = new HashMap<>();

webSocketHandlers.put(webSocketPath, new WebSocketHandlerClassWrapper(WebSocketTestHandler.class));

servletContextHandler = WebSocketServletContextHandlerFactory.create(webSocketHandlers, Optional.empty());

Expand Down

0 comments on commit 11a78b5

Please sign in to comment.