Skip to content

Commit

Permalink
UNDERTOW-1249 Unable to decode Binary Messages if Text Decoder is set
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartwdouglas committed Jan 4, 2018
1 parent aff98b5 commit 0ed35ca
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 24 deletions.
Expand Up @@ -39,7 +39,10 @@
import java.io.Reader; import java.io.Reader;
import java.io.StringReader; import java.io.StringReader;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Set; import java.util.Set;
Expand Down Expand Up @@ -376,47 +379,55 @@ public final void addHandler(MessageHandler handler) {
private void addHandlerInternal(MessageHandler handler, Class<?> type, boolean partial) { private void addHandlerInternal(MessageHandler handler, Class<?> type, boolean partial) {
verify(type, handler); verify(type, handler);


HandlerWrapper handlerWrapper = createHandlerWrapper(type, handler, partial); List<HandlerWrapper> handlerWrappers = createHandlerWrappers(type, handler, partial);

for(HandlerWrapper handlerWrapper : handlerWrappers) {
if (handlers.containsKey(handlerWrapper.getFrameType())) { if (handlers.containsKey(handlerWrapper.getFrameType())) {
throw JsrWebSocketMessages.MESSAGES.handlerAlreadyRegistered(handlerWrapper.getFrameType());
} else {
if (handlers.putIfAbsent(handlerWrapper.getFrameType(), handlerWrapper) != null) {
throw JsrWebSocketMessages.MESSAGES.handlerAlreadyRegistered(handlerWrapper.getFrameType()); throw JsrWebSocketMessages.MESSAGES.handlerAlreadyRegistered(handlerWrapper.getFrameType());
} else {
if (handlers.putIfAbsent(handlerWrapper.getFrameType(), handlerWrapper) != null) {
throw JsrWebSocketMessages.MESSAGES.handlerAlreadyRegistered(handlerWrapper.getFrameType());
}
} }
} }
} }


/** /**
* Return the {@link FrameType} for the given {@link Class}. * Return the {@link FrameType} for the given {@link Class}.
*
* Note that multiple wrappers can be returned if both text and binary frames can be decoded to the given type
*/ */
protected HandlerWrapper createHandlerWrapper(Class<?> type, MessageHandler handler, boolean partialHandler) { protected List<HandlerWrapper> createHandlerWrappers(Class<?> type, MessageHandler handler, boolean partialHandler) {
//check the encodings first //check the encodings first
Encoding encoding = session.getEncoding(); Encoding encoding = session.getEncoding();
List<HandlerWrapper> ret = new ArrayList<>(2);
if (encoding.canDecodeText(type)) { if (encoding.canDecodeText(type)) {
return new HandlerWrapper(FrameType.TEXT, handler, type, true, false); ret.add(new HandlerWrapper(FrameType.TEXT, handler, type, true, false));
} else if (encoding.canDecodeBinary(type)) { }
return new HandlerWrapper(FrameType.BYTE, handler, type, true, false); if (encoding.canDecodeBinary(type)) {
ret.add(new HandlerWrapper(FrameType.BYTE, handler, type, true, false));
}
if(!ret.isEmpty()) {
return ret;
} }
if (partialHandler) { if (partialHandler) {
// Partial message handler supports only String, byte[] and ByteBuffer. // Partial message handler supports only String, byte[] and ByteBuffer.
// See JavaDocs of the MessageHandler.Partial interface. // See JavaDocs of the MessageHandler.Partial interface.
if (type == String.class) { if (type == String.class) {
return new HandlerWrapper(FrameType.TEXT, handler, type, false, true); return Collections.singletonList(new HandlerWrapper(FrameType.TEXT, handler, type, false, true));
} }
if (type == byte[].class || type == ByteBuffer.class) { if (type == byte[].class || type == ByteBuffer.class) {
return new HandlerWrapper(FrameType.BYTE, handler, type, false, true); return Collections.singletonList(new HandlerWrapper(FrameType.BYTE, handler, type, false, true));
} }
throw JsrWebSocketMessages.MESSAGES.unsupportedFrameType(type); throw JsrWebSocketMessages.MESSAGES.unsupportedFrameType(type);
} }
if (type == byte[].class || type == ByteBuffer.class || type == InputStream.class) { if (type == byte[].class || type == ByteBuffer.class || type == InputStream.class) {
return new HandlerWrapper(FrameType.BYTE, handler, type, false, false); return Collections.singletonList(new HandlerWrapper(FrameType.BYTE, handler, type, false, false));
} }
if (type == String.class || type == Reader.class) { if (type == String.class || type == Reader.class) {
return new HandlerWrapper(FrameType.TEXT, handler, type, false, false); return Collections.singletonList(new HandlerWrapper(FrameType.TEXT, handler, type, false, false));
} }
if (type == PongMessage.class) { if (type == PongMessage.class) {
return new HandlerWrapper(FrameType.PONG, handler, type, false, false); return Collections.singletonList(new HandlerWrapper(FrameType.PONG, handler, type, false, false));
} }
throw JsrWebSocketMessages.MESSAGES.unsupportedFrameType(type); throw JsrWebSocketMessages.MESSAGES.unsupportedFrameType(type);
} }
Expand All @@ -432,11 +443,13 @@ public final void removeHandler(MessageHandler handler) {
Map<Class<?>, Boolean> types = ClassUtils.getHandlerTypes(handler.getClass()); Map<Class<?>, Boolean> types = ClassUtils.getHandlerTypes(handler.getClass());
for (Entry<Class<?>, Boolean> e : types.entrySet()) { for (Entry<Class<?>, Boolean> e : types.entrySet()) {
Class<?> type = e.getKey(); Class<?> type = e.getKey();
HandlerWrapper handlerWrapper = createHandlerWrapper(type, handler, e.getValue()); List<HandlerWrapper> handlerWrappers = createHandlerWrappers(type, handler, e.getValue());
FrameType frameType = handlerWrapper.getFrameType(); for(HandlerWrapper handlerWrapper : handlerWrappers) {
HandlerWrapper wrapper = handlers.get(frameType); FrameType frameType = handlerWrapper.getFrameType();
if (wrapper != null && wrapper.getMessageType() == type) { HandlerWrapper wrapper = handlers.get(frameType);
handlers.remove(frameType, wrapper); if (wrapper != null && wrapper.getMessageType() == type) {
handlers.remove(frameType, wrapper);
}
} }
} }
} }
Expand Down
Expand Up @@ -41,6 +41,7 @@




import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketVersion; import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.undertow.Handlers; import io.undertow.Handlers;
Expand Down Expand Up @@ -281,7 +282,7 @@ public void testImplicitIntegerConversion() throws Exception {




@Test @Test
public void testEncodingAndDecoding() throws Exception { public void testEncodingAndDecodingText() throws Exception {
final byte[] payload = "hello".getBytes(); final byte[] payload = "hello".getBytes();
final FutureResult latch = new FutureResult(); final FutureResult latch = new FutureResult();


Expand All @@ -291,6 +292,17 @@ public void testEncodingAndDecoding() throws Exception {
latch.getIoFuture().get(); latch.getIoFuture().get();
client.destroy(); client.destroy();
} }
@Test
public void testEncodingAndDecodingBinary() throws Exception {
final byte[] payload = "hello".getBytes();
final FutureResult latch = new FutureResult();

WebSocketTestClient client = new WebSocketTestClient(WebSocketVersion.V13, new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/ws/encoding/Stuart"));
client.connect();
client.send(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(payload)), new FrameChecker(TextWebSocketFrame.class, "hello Stuart".getBytes(), latch));
latch.getIoFuture().get();
client.destroy();
}


@Test @Test
public void testEncodingWithGenericSuperclass() throws Exception { public void testEncodingWithGenericSuperclass() throws Exception {
Expand Down
Expand Up @@ -18,6 +18,9 @@


package io.undertow.websockets.jsr.test.annotated; package io.undertow.websockets.jsr.test.annotated;


import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import javax.websocket.DecodeException; import javax.websocket.DecodeException;
import javax.websocket.EncodeException; import javax.websocket.EncodeException;
import javax.websocket.EndpointConfig; import javax.websocket.EndpointConfig;
Expand All @@ -37,7 +40,7 @@ public String getValue() {
return value; return value;
} }


public static class Encoder implements javax.websocket.Encoder.Text<EncodableObject> { public static class TextEncoder implements javax.websocket.Encoder.Text<EncodableObject> {


boolean initalized = false; boolean initalized = false;
public static volatile boolean destroyed = false; public static volatile boolean destroyed = false;
Expand All @@ -61,7 +64,7 @@ public void destroy() {
} }
} }


public static class Decoder implements javax.websocket.Decoder.Text<EncodableObject> { public static class TextDecoder implements javax.websocket.Decoder.Text<EncodableObject> {


boolean initalized = false; boolean initalized = false;
public static volatile boolean destroyed = false; public static volatile boolean destroyed = false;
Expand Down Expand Up @@ -89,4 +92,35 @@ public boolean willDecode(final String s) {
return true; return true;
} }
} }

public static class BinaryDecoder implements javax.websocket.Decoder.Binary<EncodableObject> {

boolean initalized = false;
public static volatile boolean destroyed = false;

@Override
public void init(final EndpointConfig config) {
initalized = true;
}

@Override
public void destroy() {
destroyed = true;
}

@Override
public EncodableObject decode(final ByteBuffer s) throws DecodeException {
if(!initalized) {
throw new DecodeException(s, "not initialized");
}
byte[] data = new byte[s.remaining()];
s.get(data);
return new EncodableObject(new String(data, StandardCharsets.US_ASCII));
}

@Override
public boolean willDecode(final ByteBuffer s) {
return true;
}
}
} }
Expand Up @@ -25,7 +25,7 @@
/** /**
* @author Stuart Douglas * @author Stuart Douglas
*/ */
@ServerEndpoint(value = "/encoding/{user}", encoders = EncodableObject.Encoder.class, decoders = EncodableObject.Decoder.class) @ServerEndpoint(value = "/encoding/{user}", encoders = {EncodableObject.TextEncoder.class}, decoders = {EncodableObject.TextDecoder.class, EncodableObject.BinaryDecoder.class})
public class EncodingEndpoint { public class EncodingEndpoint {


@OnMessage @OnMessage
Expand Down

0 comments on commit 0ed35ca

Please sign in to comment.