From 9fdf8dcc9b6b48022cb535578f191ee587f7a311 Mon Sep 17 00:00:00 2001 From: "xeroman.kyn" Date: Mon, 22 Sep 2025 23:42:19 +0900 Subject: [PATCH] Ensure thread-safe ByteBuffer handling in WebSocket sessions Replace direct ByteBuffer usage with asReadOnlyBuffer() in binary message sending to prevent concurrent modification issues when sharing buffers across multiple sessions. Signed-off-by: xeroman.k --- .../adapter/jetty/JettyWebSocketSession.java | 2 +- .../standard/StandardWebSocketSession.java | 2 +- .../jetty/JettyWebSocketSessionTests.java | 50 +++++++++++++++++++ .../StandardWebSocketSessionTests.java | 47 +++++++++++++++++ 4 files changed, 99 insertions(+), 2 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java index 628404f3f749..d795f9d6f854 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java @@ -205,7 +205,7 @@ protected void sendTextMessage(TextMessage message) throws IOException { @Override protected void sendBinaryMessage(BinaryMessage message) throws IOException { - useSession((session, callback) -> session.sendBinary(message.getPayload(), callback)); + useSession((session, callback) -> session.sendBinary(message.getPayload().asReadOnlyBuffer(), callback)); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java index 5d75b369ab32..b9dd0a331d6a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSession.java @@ -208,7 +208,7 @@ protected void sendTextMessage(TextMessage message) throws IOException { @Override protected void sendBinaryMessage(BinaryMessage message) throws IOException { - getNativeSession().getBasicRemote().sendBinary(message.getPayload(), message.isLast()); + getNativeSession().getBasicRemote().sendBinary(message.getPayload().asReadOnlyBuffer(), message.isLast()); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java index 0622571b871d..43f031879697 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSessionTests.java @@ -16,17 +16,23 @@ package org.springframework.web.socket.adapter.jetty; +import java.nio.ByteBuffer; import java.util.Map; +import java.util.function.BiConsumer; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeResponse; import org.junit.jupiter.api.Test; import org.springframework.core.testfixture.security.TestPrincipal; +import org.springframework.web.socket.BinaryMessage; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -117,4 +123,48 @@ void getAcceptedProtocol() { verifyNoMoreInteractions(nativeSession); } + @Test + void binaryMessageWithSharedBufferSendsToMultipleSessions() throws Exception { + byte[] data = {1, 2, 3, 4, 5}; + ByteBuffer sharedBuffer = ByteBuffer.wrap(data); + BinaryMessage message = new BinaryMessage(sharedBuffer); + + ByteBuffer[] captured = new ByteBuffer[2]; + JettyWebSocketSession session1 = createMockSession((buffer, idx) -> captured[0] = buffer); + JettyWebSocketSession session2 = createMockSession((buffer, idx) -> captured[1] = buffer); + + session1.sendMessage(message); + session2.sendMessage(message); + + assertThat(captured[0].array()).isEqualTo(data); + assertThat(captured[1].array()).isEqualTo(data); + + assertThat(sharedBuffer.position()).isEqualTo(0); + } + + private JettyWebSocketSession createMockSession(BiConsumer captureFunction) { + Session mockSession = mock(Session.class); + + given(mockSession.getUpgradeRequest()).willReturn(request); + given(mockSession.getUpgradeResponse()).willReturn(response); + given(mockSession.isOpen()).willReturn(true); + given(request.getUserPrincipal()).willReturn(null); + given(response.getAcceptedSubProtocol()).willReturn(null); + + doAnswer(invocation -> { + ByteBuffer buffer = invocation.getArgument(0); + Callback callback = invocation.getArgument(1); + ByteBuffer copy = ByteBuffer.allocate(buffer.remaining()); + copy.put(buffer); + copy.flip(); + captureFunction.accept(copy, 0); + callback.succeed(); + return null; + }).when(mockSession).sendBinary(any(ByteBuffer.class), any(Callback.class)); + + JettyWebSocketSession session = new JettyWebSocketSession(attributes); + session.initializeNativeSession(mockSession); + return session; + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java index c3b251805425..dcfa7d832c95 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/adapter/standard/StandardWebSocketSessionTests.java @@ -16,17 +16,24 @@ package org.springframework.web.socket.adapter.standard; +import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Map; +import java.util.function.BiConsumer; +import jakarta.websocket.RemoteEndpoint.Basic; import jakarta.websocket.Session; import org.junit.jupiter.api.Test; import org.springframework.core.testfixture.security.TestPrincipal; import org.springframework.http.HttpHeaders; +import org.springframework.web.socket.BinaryMessage; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -106,4 +113,44 @@ void addAttributesWithNullKeyOrValue() { .hasSize(1).containsEntry("foo", "bar"); } + @Test + void binaryMessageWithSharedBufferSendsToMultipleSessions() throws Exception { + byte[] data = {1, 2, 3, 4, 5}; + ByteBuffer sharedBuffer = ByteBuffer.wrap(data); + BinaryMessage message = new BinaryMessage(sharedBuffer); + + ByteBuffer[] captured = new ByteBuffer[2]; + StandardWebSocketSession session1 = createMockSession((buffer, idx) -> captured[0] = buffer); + StandardWebSocketSession session2 = createMockSession((buffer, idx) -> captured[1] = buffer); + + session1.sendMessage(message); + session2.sendMessage(message); + + assertThat(captured[0].array()).isEqualTo(data); + assertThat(captured[1].array()).isEqualTo(data); + + assertThat(sharedBuffer.position()).isEqualTo(0); + } + + private StandardWebSocketSession createMockSession(BiConsumer captureFunction) throws Exception { + Session nativeSession = mock(Session.class); + Basic basicRemote = mock(Basic.class); + + given(nativeSession.getBasicRemote()).willReturn(basicRemote); + given(nativeSession.isOpen()).willReturn(true); + + doAnswer(invocation -> { + ByteBuffer buffer = invocation.getArgument(0); + ByteBuffer copy = ByteBuffer.allocate(buffer.remaining()); + copy.put(buffer); + copy.flip(); + captureFunction.accept(copy, 0); + return null; + }).when(basicRemote).sendBinary(any(ByteBuffer.class), anyBoolean()); + + StandardWebSocketSession session = new StandardWebSocketSession(this.headers, this.attributes, null, null); + session.initializeNativeSession(nativeSession); + return session; + } + }