diff --git a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/FrameHandler.java b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/FrameHandler.java index 4d0c49f67d..12ae5bb38c 100644 --- a/websockets-jsr/src/main/java/io/undertow/websockets/jsr/FrameHandler.java +++ b/websockets-jsr/src/main/java/io/undertow/websockets/jsr/FrameHandler.java @@ -39,7 +39,10 @@ import java.io.Reader; import java.io.StringReader; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; @@ -376,47 +379,55 @@ public final void addHandler(MessageHandler handler) { private void addHandlerInternal(MessageHandler handler, Class type, boolean partial) { verify(type, handler); - HandlerWrapper handlerWrapper = createHandlerWrapper(type, handler, partial); - - if (handlers.containsKey(handlerWrapper.getFrameType())) { - throw JsrWebSocketMessages.MESSAGES.handlerAlreadyRegistered(handlerWrapper.getFrameType()); - } else { - if (handlers.putIfAbsent(handlerWrapper.getFrameType(), handlerWrapper) != null) { + List handlerWrappers = createHandlerWrappers(type, handler, partial); + for(HandlerWrapper handlerWrapper : handlerWrappers) { + if (handlers.containsKey(handlerWrapper.getFrameType())) { throw JsrWebSocketMessages.MESSAGES.handlerAlreadyRegistered(handlerWrapper.getFrameType()); + } else { + if (handlers.putIfAbsent(handlerWrapper.getFrameType(), handlerWrapper) != null) { + throw JsrWebSocketMessages.MESSAGES.handlerAlreadyRegistered(handlerWrapper.getFrameType()); + } } } } /** * Return the {@link FrameType} for the given {@link Class}. + * + * Note that multiple wrappers can be returned if both text and binary frames can be decoded to the given type */ - protected HandlerWrapper createHandlerWrapper(Class type, MessageHandler handler, boolean partialHandler) { + protected List createHandlerWrappers(Class type, MessageHandler handler, boolean partialHandler) { //check the encodings first Encoding encoding = session.getEncoding(); + List ret = new ArrayList<>(2); if (encoding.canDecodeText(type)) { - return new HandlerWrapper(FrameType.TEXT, handler, type, true, false); - } else if (encoding.canDecodeBinary(type)) { - return new HandlerWrapper(FrameType.BYTE, handler, type, true, false); + ret.add(new HandlerWrapper(FrameType.TEXT, handler, type, true, false)); + } + if (encoding.canDecodeBinary(type)) { + ret.add(new HandlerWrapper(FrameType.BYTE, handler, type, true, false)); + } + if(!ret.isEmpty()) { + return ret; } if (partialHandler) { // Partial message handler supports only String, byte[] and ByteBuffer. // See JavaDocs of the MessageHandler.Partial interface. if (type == String.class) { - return new HandlerWrapper(FrameType.TEXT, handler, type, false, true); + return Collections.singletonList(new HandlerWrapper(FrameType.TEXT, handler, type, false, true)); } if (type == byte[].class || type == ByteBuffer.class) { - return new HandlerWrapper(FrameType.BYTE, handler, type, false, true); + return Collections.singletonList(new HandlerWrapper(FrameType.BYTE, handler, type, false, true)); } throw JsrWebSocketMessages.MESSAGES.unsupportedFrameType(type); } if (type == byte[].class || type == ByteBuffer.class || type == InputStream.class) { - return new HandlerWrapper(FrameType.BYTE, handler, type, false, false); + return Collections.singletonList(new HandlerWrapper(FrameType.BYTE, handler, type, false, false)); } if (type == String.class || type == Reader.class) { - return new HandlerWrapper(FrameType.TEXT, handler, type, false, false); + return Collections.singletonList(new HandlerWrapper(FrameType.TEXT, handler, type, false, false)); } if (type == PongMessage.class) { - return new HandlerWrapper(FrameType.PONG, handler, type, false, false); + return Collections.singletonList(new HandlerWrapper(FrameType.PONG, handler, type, false, false)); } throw JsrWebSocketMessages.MESSAGES.unsupportedFrameType(type); } @@ -432,11 +443,13 @@ public final void removeHandler(MessageHandler handler) { Map, Boolean> types = ClassUtils.getHandlerTypes(handler.getClass()); for (Entry, Boolean> e : types.entrySet()) { Class type = e.getKey(); - HandlerWrapper handlerWrapper = createHandlerWrapper(type, handler, e.getValue()); - FrameType frameType = handlerWrapper.getFrameType(); - HandlerWrapper wrapper = handlers.get(frameType); - if (wrapper != null && wrapper.getMessageType() == type) { - handlers.remove(frameType, wrapper); + List handlerWrappers = createHandlerWrappers(type, handler, e.getValue()); + for(HandlerWrapper handlerWrapper : handlerWrappers) { + FrameType frameType = handlerWrapper.getFrameType(); + HandlerWrapper wrapper = handlers.get(frameType); + if (wrapper != null && wrapper.getMessageType() == type) { + handlers.remove(frameType, wrapper); + } } } } diff --git a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/AnnotatedEndpointTest.java b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/AnnotatedEndpointTest.java index 5ebb0256df..fe682199c8 100644 --- a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/AnnotatedEndpointTest.java +++ b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/AnnotatedEndpointTest.java @@ -41,6 +41,7 @@ import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketVersion; import io.undertow.Handlers; @@ -281,7 +282,7 @@ public void testImplicitIntegerConversion() throws Exception { @Test - public void testEncodingAndDecoding() throws Exception { + public void testEncodingAndDecodingText() throws Exception { final byte[] payload = "hello".getBytes(); final FutureResult latch = new FutureResult(); @@ -291,6 +292,17 @@ public void testEncodingAndDecoding() throws Exception { latch.getIoFuture().get(); client.destroy(); } + @Test + public void testEncodingAndDecodingBinary() throws Exception { + final byte[] payload = "hello".getBytes(); + final FutureResult latch = new FutureResult(); + + WebSocketTestClient client = new WebSocketTestClient(WebSocketVersion.V13, new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/encoding/Stuart")); + client.connect(); + client.send(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "hello Stuart".getBytes(), latch)); + latch.getIoFuture().get(); + client.destroy(); + } @Test public void testEncodingWithGenericSuperclass() throws Exception { diff --git a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/EncodableObject.java b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/EncodableObject.java index 8ada1ae704..b71222fe33 100644 --- a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/EncodableObject.java +++ b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/EncodableObject.java @@ -18,6 +18,9 @@ package io.undertow.websockets.jsr.test.annotated; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + import javax.websocket.DecodeException; import javax.websocket.EncodeException; import javax.websocket.EndpointConfig; @@ -37,7 +40,7 @@ public String getValue() { return value; } - public static class Encoder implements javax.websocket.Encoder.Text { + public static class TextEncoder implements javax.websocket.Encoder.Text { boolean initalized = false; public static volatile boolean destroyed = false; @@ -61,7 +64,7 @@ public void destroy() { } } - public static class Decoder implements javax.websocket.Decoder.Text { + public static class TextDecoder implements javax.websocket.Decoder.Text { boolean initalized = false; public static volatile boolean destroyed = false; @@ -89,4 +92,35 @@ public boolean willDecode(final String s) { return true; } } + + public static class BinaryDecoder implements javax.websocket.Decoder.Binary { + + boolean initalized = false; + public static volatile boolean destroyed = false; + + @Override + public void init(final EndpointConfig config) { + initalized = true; + } + + @Override + public void destroy() { + destroyed = true; + } + + @Override + public EncodableObject decode(final ByteBuffer s) throws DecodeException { + if(!initalized) { + throw new DecodeException(s, "not initialized"); + } + byte[] data = new byte[s.remaining()]; + s.get(data); + return new EncodableObject(new String(data, StandardCharsets.US_ASCII)); + } + + @Override + public boolean willDecode(final ByteBuffer s) { + return true; + } + } } diff --git a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/EncodingEndpoint.java b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/EncodingEndpoint.java index 864e934d62..008ce03990 100644 --- a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/EncodingEndpoint.java +++ b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/annotated/EncodingEndpoint.java @@ -25,7 +25,7 @@ /** * @author Stuart Douglas */ -@ServerEndpoint(value = "/encoding/{user}", encoders = EncodableObject.Encoder.class, decoders = EncodableObject.Decoder.class) +@ServerEndpoint(value = "/encoding/{user}", encoders = {EncodableObject.TextEncoder.class}, decoders = {EncodableObject.TextDecoder.class, EncodableObject.BinaryDecoder.class}) public class EncodingEndpoint { @OnMessage