diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java index 0113f89b0bf6..304f10c4b608 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java @@ -17,6 +17,7 @@ package org.springframework.web.socket.handler; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Queue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; @@ -28,6 +29,7 @@ import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; +import org.springframework.web.socket.BinaryMessage; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; @@ -43,6 +45,7 @@ * * @author Rossen Stoyanchev * @author Juergen Hoeller + * @author xeroman.k * @since 4.0.3 */ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorator { @@ -170,7 +173,7 @@ public void sendMessage(WebSocketMessage message) throws IOException { if (!tryFlushMessageBuffer()) { if (logger.isTraceEnabled()) { logger.trace(String.format("Another send already in progress: " + - "session id '%s':, \"in-progress\" send time %d (ms), buffer size %d bytes", + "session id '%s':, \"in-progress\" send time %d (ms), buffer size %d bytes", getId(), getTimeSinceSendStarted(), getBufferSize())); } checkSessionLimits(); @@ -194,7 +197,7 @@ private boolean tryFlushMessageBuffer() throws IOException { } this.bufferSize.addAndGet(-message.getPayloadLength()); this.sendStartTime = System.currentTimeMillis(); - getDelegate().sendMessage(message); + getDelegate().sendMessage(prepareMessage(message)); this.sendStartTime = 0; } } @@ -207,6 +210,14 @@ private boolean tryFlushMessageBuffer() throws IOException { return false; } + private WebSocketMessage prepareMessage(WebSocketMessage message) { + if (message instanceof BinaryMessage) { + ByteBuffer payload = ((BinaryMessage) message).getPayload(); + return new BinaryMessage(payload.duplicate(), message.isLast()); + } + return message; + } + private void checkSessionLimits() { if (!shouldNotSend() && this.closeLock.tryLock()) { try { @@ -238,7 +249,7 @@ else if (getBufferSize() > getBufferSizeLimit()) { } default -> // Should never happen.. - throw new IllegalStateException("Unexpected OverflowStrategy: " + this.overflowStrategy); + throw new IllegalStateException("Unexpected OverflowStrategy: " + this.overflowStrategy); } } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java index 85678fce779a..a443efad4b70 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java @@ -17,14 +17,21 @@ package org.springframework.web.socket.handler; import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import org.springframework.web.socket.BinaryMessage; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator.OverflowStrategy; @@ -35,6 +42,7 @@ * Tests for {@link ConcurrentWebSocketSessionDecorator}. * * @author Rossen Stoyanchev + * @author xeroman.k */ class ConcurrentWebSocketSessionDecoratorTests { @@ -98,9 +106,9 @@ void sendTimeLimitExceeded() throws InterruptedException { TextMessage payload = new TextMessage("payload"); assertThatExceptionOfType(SessionLimitExceededException.class).isThrownBy(() -> - decorator.sendMessage(payload)) - .withMessageMatching("Send time [\\d]+ \\(ms\\) for session '123' exceeded the allowed limit 100") - .satisfies(ex -> assertThat(ex.getStatus()).isEqualTo(CloseStatus.SESSION_NOT_RELIABLE)); + decorator.sendMessage(payload)) + .withMessageMatching("Send time [\\d]+ \\(ms\\) for session '123' exceeded the allowed limit 100") + .satisfies(ex -> assertThat(ex.getStatus()).isEqualTo(CloseStatus.SESSION_NOT_RELIABLE)); } @Test @@ -123,9 +131,9 @@ void sendBufferSizeExceeded() throws IOException, InterruptedException { assertThat(session.isOpen()).isTrue(); assertThatExceptionOfType(SessionLimitExceededException.class).isThrownBy(() -> - decorator.sendMessage(message)) - .withMessageMatching("Buffer size [\\d]+ bytes for session '123' exceeds the allowed limit 1024") - .satisfies(ex -> assertThat(ex.getStatus()).isEqualTo(CloseStatus.SESSION_NOT_RELIABLE)); + decorator.sendMessage(message)) + .withMessageMatching("Buffer size [\\d]+ bytes for session '123' exceeds the allowed limit 1024") + .satisfies(ex -> assertThat(ex.getStatus()).isEqualTo(CloseStatus.SESSION_NOT_RELIABLE)); } @Test // SPR-17140 @@ -226,4 +234,119 @@ void configuredProperties() { assertThat(sessionDecorator.getOverflowStrategy()).isEqualTo(OverflowStrategy.DROP); } + @Test + void concurrentBinaryMessageSharingAcrossSessions() throws Exception { + byte[] originalData = new byte[100]; + for (int i = 0; i < originalData.length; i++) { + originalData[i] = (byte) i; + } + ByteBuffer buffer = ByteBuffer.wrap(originalData); + BinaryMessage sharedMessage = new BinaryMessage(buffer); + + int sessionCount = 5; + int messagesPerSession = 3; + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch completeLatch = new CountDownLatch(sessionCount * messagesPerSession); + AtomicInteger corruptedBuffers = new AtomicInteger(0); + + List sessions = new ArrayList<>(); + List decorators = new ArrayList<>(); + + for (int i = 0; i < sessionCount; i++) { + TestBinaryMessageCapturingSession session = new TestBinaryMessageCapturingSession(); + session.setOpen(true); + session.setId("session-" + i); + sessions.add(session); + + ConcurrentWebSocketSessionDecorator decorator = + new ConcurrentWebSocketSessionDecorator(session, 10000, 10240); + decorators.add(decorator); + } + + ExecutorService executor = Executors.newFixedThreadPool(sessionCount * messagesPerSession); + + try { + for (ConcurrentWebSocketSessionDecorator decorator : decorators) { + for (int j = 0; j < messagesPerSession; j++) { + executor.submit(() -> { + try { + startLatch.await(); + decorator.sendMessage(sharedMessage); + } catch (Exception e) { + e.printStackTrace(); + } finally { + completeLatch.countDown(); + } + }); + } + } + + startLatch.countDown(); + assertThat(completeLatch.await(5, TimeUnit.SECONDS)).isTrue(); + + for (TestBinaryMessageCapturingSession session : sessions) { + List capturedBuffers = session.getCapturedBuffers(); + + for (ByteBuffer capturedBuffer : capturedBuffers) { + byte[] capturedData = new byte[capturedBuffer.remaining()]; + capturedBuffer.get(capturedData); + + boolean isCorrupted = false; + if (capturedData.length != originalData.length) { + isCorrupted = true; + } else { + for (int j = 0; j < originalData.length; j++) { + if (capturedData[j] != originalData[j]) { + isCorrupted = true; + break; + } + } + } + + if (isCorrupted) { + corruptedBuffers.incrementAndGet(); + } + } + } + + assertThat(corruptedBuffers.get()) + .as("No ByteBuffer corruption should occur with duplicate() fix") + .isEqualTo(0); + } finally { + executor.shutdown(); + } + } + + static class TestBinaryMessageCapturingSession extends TestWebSocketSession { + private final List capturedBuffers = new ArrayList<>(); + + @Override + public void sendMessage(WebSocketMessage message) throws IOException { + if (message instanceof BinaryMessage) { + ByteBuffer payload = ((BinaryMessage) message).getPayload(); + ByteBuffer captured = ByteBuffer.allocate(payload.remaining()); + + while (payload.hasRemaining()) { + captured.put(payload.get()); + } + captured.flip(); + + synchronized (capturedBuffers) { + capturedBuffers.add(captured); + } + + try { + Thread.sleep(1); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + super.sendMessage(message); + } + + public synchronized List getCapturedBuffers() { + return new ArrayList<>(capturedBuffers); + } + } + }