diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java index 31cfc1addeba..5fdd6f0d0a3a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.lang.reflect.Method; import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; import java.security.Principal; import java.util.ArrayList; @@ -143,13 +144,13 @@ public Principal getPrincipal() { @Override public InetSocketAddress getLocalAddress() { checkNativeSessionInitialized(); - return getNativeSession().getLocalAddress(); + return this.sessionHelper.getLocalAddress(getNativeSession()); } @Override public InetSocketAddress getRemoteAddress() { checkNativeSessionInitialized(); - return getNativeSession().getRemoteAddress(); + return this.sessionHelper.getRemoteAddress(getNativeSession()); } /** @@ -248,6 +249,11 @@ private interface SessionHelper { int getTextMessageSizeLimit(Session session); int getBinaryMessageSizeLimit(Session session); + + InetSocketAddress getRemoteAddress(Session session); + + InetSocketAddress getLocalAddress(Session session); + } @@ -275,6 +281,16 @@ public int getTextMessageSizeLimit(Session session) { public int getBinaryMessageSizeLimit(Session session) { return session.getPolicy().getMaxBinaryMessageSize(); } + + @Override + public InetSocketAddress getRemoteAddress(Session session) { + return session.getRemoteAddress(); + } + + @Override + public InetSocketAddress getLocalAddress(Session session) { + return session.getLocalAddress(); + } } @@ -284,11 +300,17 @@ private static class Jetty10SessionHelper implements SessionHelper { private static final Method getBinaryMessageSizeLimitMethod; + private static final Method getRemoteAddressMethod; + + private static final Method getLocalAddressMethod; + static { try { Class type = loader.loadClass("org.eclipse.jetty.websocket.api.WebSocketPolicy"); getTextMessageSizeLimitMethod = type.getMethod("getMaxTextMessageSize"); getBinaryMessageSizeLimitMethod = type.getMethod("getMaxBinaryMessageSize"); + getRemoteAddressMethod = type.getMethod("getRemoteAddress"); + getLocalAddressMethod = type.getMethod("getLocalAddress"); } catch (ClassNotFoundException | NoSuchMethodException ex) { throw new IllegalStateException("No compatible Jetty version found", ex); @@ -321,6 +343,22 @@ public int getBinaryMessageSizeLimit(Session session) { Assert.state(result <= Integer.MAX_VALUE, "binaryMessageSizeLimit is larger than Integer.MAX_VALUE"); return (int) result; } + + @Override + @SuppressWarnings("ConstantConditions") + public InetSocketAddress getRemoteAddress(Session session) { + SocketAddress address = (SocketAddress) ReflectionUtils.invokeMethod(getRemoteAddressMethod, session); + Assert.isInstanceOf(InetSocketAddress.class, address); + return (InetSocketAddress) address; + } + + @Override + @SuppressWarnings("ConstantConditions") + public InetSocketAddress getLocalAddress(Session session) { + SocketAddress address = (SocketAddress) ReflectionUtils.invokeMethod(getLocalAddressMethod, session); + Assert.isInstanceOf(InetSocketAddress.class, address); + return (InetSocketAddress) address; + } } }