Skip to content

Commit

Permalink
ByteBuffer handling for Jetty WebSocket messages
Browse files Browse the repository at this point in the history
Closes gh-31182
  • Loading branch information
rstoyanchev committed Sep 12, 2023
1 parent f51838b commit ed172d6
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package org.springframework.web.reactive.socket.adapter;

import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.function.Function;
import java.util.function.IntPredicate;

import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Frame;
Expand All @@ -31,14 +33,15 @@
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.core.OpCode;

import org.springframework.core.io.buffer.CloseableDataBuffer;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketMessage.Type;
import org.springframework.web.reactive.socket.WebSocketSession;

/**
* Jetty {@link WebSocket @WebSocket} handler that delegates events to a
Expand Down Expand Up @@ -83,53 +86,36 @@ public void onWebSocketOpen(Session session) {
@OnWebSocketMessage
public void onWebSocketText(String message) {
if (this.delegateSession != null) {
WebSocketMessage webSocketMessage = toMessage(Type.TEXT, message);
byte[] bytes = message.getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = this.delegateSession.bufferFactory().wrap(bytes);
WebSocketMessage webSocketMessage = new WebSocketMessage(Type.TEXT, buffer);
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
}
}

@OnWebSocketMessage
public void onWebSocketBinary(ByteBuffer buffer, Callback callback) {
public void onWebSocketBinary(ByteBuffer byteBuffer, Callback callback) {
if (this.delegateSession != null) {
WebSocketMessage webSocketMessage = toMessage(Type.BINARY, buffer);
DataBuffer buffer = this.delegateSession.bufferFactory().wrap(byteBuffer);
buffer = new JettyDataBuffer(buffer, callback);
WebSocketMessage webSocketMessage = new WebSocketMessage(Type.BINARY, buffer);
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
callback.succeed();
}
}

@OnWebSocketFrame
public void onWebSocketFrame(Frame frame, Callback callback) {
if (this.delegateSession != null) {
if (OpCode.PONG == frame.getOpCode()) {
ByteBuffer buffer = (frame.getPayload() != null ? frame.getPayload() : EMPTY_PAYLOAD);
WebSocketMessage webSocketMessage = toMessage(Type.PONG, buffer);
ByteBuffer byteBuffer = (frame.getPayload() != null ? frame.getPayload() : EMPTY_PAYLOAD);
DataBuffer buffer = this.delegateSession.bufferFactory().wrap(byteBuffer);
buffer = new JettyDataBuffer(buffer, callback);
WebSocketMessage webSocketMessage = new WebSocketMessage(Type.PONG, buffer);
this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage);
callback.succeed();
}
}
}

private <T> WebSocketMessage toMessage(Type type, T message) {
WebSocketSession session = this.delegateSession;
Assert.state(session != null, "Cannot create message without a session");
if (Type.TEXT.equals(type)) {
byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8);
DataBuffer buffer = session.bufferFactory().wrap(bytes);
return new WebSocketMessage(Type.TEXT, buffer);
}
else if (Type.BINARY.equals(type)) {
DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message);
return new WebSocketMessage(Type.BINARY, buffer);
}
else if (Type.PONG.equals(type)) {
DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message);
return new WebSocketMessage(Type.PONG, buffer);
}
else {
throw new IllegalArgumentException("Unexpected message type: " + message);
}
}

@OnWebSocketClose
public void onWebSocketClose(int statusCode, String reason) {
if (this.delegateSession != null) {
Expand All @@ -144,4 +130,215 @@ public void onWebSocketError(Throwable cause) {
}
}


private static final class JettyDataBuffer implements CloseableDataBuffer {

private final DataBuffer delegate;

private final Callback callback;

public JettyDataBuffer(DataBuffer delegate, Callback callback) {
Assert.notNull(delegate, "'delegate` must not be null");
Assert.notNull(callback, "Callback must not be null");
this.delegate = delegate;
this.callback = callback;
}

@Override
public void close() {
this.callback.succeed();
}

// delegation

@Override
public DataBufferFactory factory() {
return this.delegate.factory();
}

@Override
public int indexOf(IntPredicate predicate, int fromIndex) {
return this.delegate.indexOf(predicate, fromIndex);
}

@Override
public int lastIndexOf(IntPredicate predicate, int fromIndex) {
return this.delegate.lastIndexOf(predicate, fromIndex);
}

@Override
public int readableByteCount() {
return this.delegate.readableByteCount();
}

@Override
public int writableByteCount() {
return this.delegate.writableByteCount();
}

@Override
public int capacity() {
return this.delegate.capacity();
}

@Override
@Deprecated
public DataBuffer capacity(int capacity) {
this.delegate.capacity(capacity);
return this;
}

@Override
public DataBuffer ensureWritable(int capacity) {
this.delegate.ensureWritable(capacity);
return this;
}

@Override
public int readPosition() {
return this.delegate.readPosition();
}

@Override
public DataBuffer readPosition(int readPosition) {
this.delegate.readPosition(readPosition);
return this;
}

@Override
public int writePosition() {
return this.delegate.writePosition();
}

@Override
public DataBuffer writePosition(int writePosition) {
this.delegate.writePosition(writePosition);
return this;
}

@Override
public byte getByte(int index) {
return this.delegate.getByte(index);
}

@Override
public byte read() {
return this.delegate.read();
}

@Override
public DataBuffer read(byte[] destination) {
this.delegate.read(destination);
return this;
}

@Override
public DataBuffer read(byte[] destination, int offset, int length) {
this.delegate.read(destination, offset, length);
return this;
}

@Override
public DataBuffer write(byte b) {
this.delegate.write(b);
return this;
}

@Override
public DataBuffer write(byte[] source) {
this.delegate.write(source);
return this;
}

@Override
public DataBuffer write(byte[] source, int offset, int length) {
this.delegate.write(source, offset, length);
return this;
}

@Override
public DataBuffer write(DataBuffer... buffers) {
this.delegate.write(buffers);
return this;
}

@Override
public DataBuffer write(ByteBuffer... buffers) {
this.delegate.write(buffers);
return this;
}

@Override
@Deprecated
public DataBuffer slice(int index, int length) {
DataBuffer delegateSlice = this.delegate.slice(index, length);
return new JettyDataBuffer(delegateSlice, this.callback);
}

@Override
public DataBuffer split(int index) {
DataBuffer delegateSplit = this.delegate.split(index);
return new JettyDataBuffer(delegateSplit, this.callback);
}

@Override
@Deprecated
public ByteBuffer asByteBuffer() {
return this.delegate.asByteBuffer();
}

@Override
@Deprecated
public ByteBuffer asByteBuffer(int index, int length) {
return this.delegate.asByteBuffer(index, length);
}

@Override
@Deprecated
public ByteBuffer toByteBuffer(int index, int length) {
return this.delegate.toByteBuffer(index, length);
}

@Override
public void toByteBuffer(int srcPos, ByteBuffer dest, int destPos, int length) {
this.delegate.toByteBuffer(srcPos, dest, destPos, length);
}

@Override
public ByteBufferIterator readableByteBuffers() {
ByteBufferIterator delegateIterator = this.delegate.readableByteBuffers();
return new JettyByteBufferIterator(delegateIterator);
}

@Override
public ByteBufferIterator writableByteBuffers() {
ByteBufferIterator delegateIterator = this.delegate.writableByteBuffers();
return new JettyByteBufferIterator(delegateIterator);
}

@Override
public String toString(int index, int length, Charset charset) {
return this.delegate.toString(index, length, charset);
}


private record JettyByteBufferIterator(ByteBufferIterator delegate) implements ByteBufferIterator {

@Override
public void close() {
this.delegate.close();
}

@Override
public boolean hasNext() {
return this.delegate.hasNext();
}

@Override
public ByteBuffer next() {
return this.delegate.next();
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,13 @@ public void onWebSocketText(String payload) {

@OnWebSocketMessage
public void onWebSocketBinary(ByteBuffer payload, Callback callback) {
BinaryMessage message = new BinaryMessage(payload, true);
BinaryMessage message = new BinaryMessage(copyByteBuffer(payload), true);
try {
this.webSocketHandler.handleMessage(this.wsSession, message);
callback.succeed();
}
catch (Exception ex) {
callback.fail(ex);
ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.wsSession, ex, logger);
}
}
Expand All @@ -103,16 +105,24 @@ public void onWebSocketBinary(ByteBuffer payload, Callback callback) {
public void onWebSocketFrame(Frame frame, Callback callback) {
if (OpCode.PONG == frame.getOpCode()) {
ByteBuffer payload = frame.getPayload() != null ? frame.getPayload() : EMPTY_PAYLOAD;
PongMessage message = new PongMessage(payload);
PongMessage message = new PongMessage(copyByteBuffer(payload));
try {
this.webSocketHandler.handleMessage(this.wsSession, message);
callback.succeed();
}
catch (Exception ex) {
callback.fail(ex);
ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.wsSession, ex, logger);
}
}
}

private static ByteBuffer copyByteBuffer(ByteBuffer src) {
ByteBuffer dest = ByteBuffer.allocate(src.capacity());
dest.put(0, src, 0, src.remaining());
return dest;
}

@OnWebSocketClose
public void onWebSocketClose(int statusCode, String reason) {
CloseStatus closeStatus = new CloseStatus(statusCode, reason);
Expand Down

0 comments on commit ed172d6

Please sign in to comment.