From f2ea8622effd331b5da86b8d84972f85827e52dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arnaud=20Cogolu=C3=A8gnes?= Date: Thu, 2 Apr 2020 11:03:06 +0200 Subject: [PATCH] Make sure qos, heartbeat, max channel are unsigned shorts Sets the value to 0 or 65535 and issues a warning if it is out of range. Fixes #642 --- .../rabbitmq/client/ConnectionFactory.java | 28 +++++++++++---- .../rabbitmq/client/impl/AMQConnection.java | 18 ++++++---- .../com/rabbitmq/client/impl/ChannelN.java | 8 +++-- .../rabbitmq/client/test/ChannelNTest.java | 32 ++++++++++++++--- .../client/test/ConnectionFactoryTest.java | 35 ++++++++++--------- 5 files changed, 84 insertions(+), 37 deletions(-) diff --git a/src/main/java/com/rabbitmq/client/ConnectionFactory.java b/src/main/java/com/rabbitmq/client/ConnectionFactory.java index 7fdcdb1324..71286d3b3b 100644 --- a/src/main/java/com/rabbitmq/client/ConnectionFactory.java +++ b/src/main/java/com/rabbitmq/client/ConnectionFactory.java @@ -21,6 +21,8 @@ import com.rabbitmq.client.impl.recovery.AutorecoveringConnection; import com.rabbitmq.client.impl.recovery.RetryHandler; import com.rabbitmq.client.impl.recovery.TopologyRecoveryFilter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import javax.net.SocketFactory; import javax.net.ssl.SSLContext; @@ -47,6 +49,8 @@ */ public class ConnectionFactory implements Cloneable { + private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionFactory.class); + private static final int MAX_UNSIGNED_SHORT = 65535; /** Default user name */ @@ -393,10 +397,11 @@ public int getRequestedChannelMax() { * @param requestedChannelMax initially requested maximum channel number; zero for unlimited */ public void setRequestedChannelMax(int requestedChannelMax) { - if (requestedChannelMax < 0 || requestedChannelMax > MAX_UNSIGNED_SHORT) { - throw new IllegalArgumentException("Requested channel max must be between 0 and " + MAX_UNSIGNED_SHORT); + this.requestedChannelMax = ensureUnsignedShort(requestedChannelMax); + if (this.requestedChannelMax != requestedChannelMax) { + LOGGER.warn("Requested channel max must be between 0 and {}, value has been set to {} instead of {}", + MAX_UNSIGNED_SHORT, this.requestedChannelMax, requestedChannelMax); } - this.requestedChannelMax = requestedChannelMax; } /** @@ -492,10 +497,11 @@ public int getShutdownTimeout() { * @see RabbitMQ Heartbeats Guide */ public void setRequestedHeartbeat(int requestedHeartbeat) { - if (requestedHeartbeat < 0 || requestedHeartbeat > MAX_UNSIGNED_SHORT) { - throw new IllegalArgumentException("Requested heartbeat must be between 0 and " + MAX_UNSIGNED_SHORT); + this.requestedHeartbeat = ensureUnsignedShort(requestedHeartbeat); + if (this.requestedHeartbeat != requestedHeartbeat) { + LOGGER.warn("Requested heartbeat must be between 0 and {}, value has been set to {} instead of {}", + MAX_UNSIGNED_SHORT, this.requestedHeartbeat, requestedHeartbeat); } - this.requestedHeartbeat = requestedHeartbeat; } /** @@ -1574,4 +1580,14 @@ public void setTopologyRecoveryRetryHandler(RetryHandler topologyRecoveryRetryHa public void setTrafficListener(TrafficListener trafficListener) { this.trafficListener = trafficListener; } + + public static int ensureUnsignedShort(int value) { + if (value < 0) { + return 0; + } else if (value > MAX_UNSIGNED_SHORT) { + return MAX_UNSIGNED_SHORT; + } else { + return value; + } + } } diff --git a/src/main/java/com/rabbitmq/client/impl/AMQConnection.java b/src/main/java/com/rabbitmq/client/impl/AMQConnection.java index b140788a84..d99784f933 100644 --- a/src/main/java/com/rabbitmq/client/impl/AMQConnection.java +++ b/src/main/java/com/rabbitmq/client/impl/AMQConnection.java @@ -400,12 +400,15 @@ public void start() } try { - int channelMax = + int negotiatedChannelMax = negotiateChannelMax(this.requestedChannelMax, connTune.getChannelMax()); - if (!checkUnsignedShort(channelMax)) { - throw new IllegalArgumentException("Negotiated channel max must be between 0 and " + MAX_UNSIGNED_SHORT + ": " + channelMax); + int channelMax = ConnectionFactory.ensureUnsignedShort(negotiatedChannelMax); + + if (channelMax != negotiatedChannelMax) { + LOGGER.warn("Channel max must be between 0 and {}, value has been set to {} instead of {}", + MAX_UNSIGNED_SHORT, channelMax, negotiatedChannelMax); } _channelManager = instantiateChannelManager(channelMax, threadFactory); @@ -415,12 +418,15 @@ public void start() connTune.getFrameMax()); this._frameMax = frameMax; - int heartbeat = + int negotiatedHeartbeat = negotiatedMaxValue(this.requestedHeartbeat, connTune.getHeartbeat()); - if (!checkUnsignedShort(heartbeat)) { - throw new IllegalArgumentException("Negotiated heartbeat must be between 0 and " + MAX_UNSIGNED_SHORT + ": " + heartbeat); + int heartbeat = ConnectionFactory.ensureUnsignedShort(negotiatedHeartbeat); + + if (heartbeat != negotiatedHeartbeat) { + LOGGER.warn("Heartbeat must be between 0 and {}, value has been set to {} instead of {}", + MAX_UNSIGNED_SHORT, heartbeat, negotiatedHeartbeat); } setHeartbeat(heartbeat); diff --git a/src/main/java/com/rabbitmq/client/impl/ChannelN.java b/src/main/java/com/rabbitmq/client/impl/ChannelN.java index 94da03e7bc..2bbc37a3ca 100644 --- a/src/main/java/com/rabbitmq/client/impl/ChannelN.java +++ b/src/main/java/com/rabbitmq/client/impl/ChannelN.java @@ -642,10 +642,12 @@ public AMQCommand transformReply(AMQCommand command) { public void basicQos(int prefetchSize, int prefetchCount, boolean global) throws IOException { - if (prefetchCount < 0 || prefetchCount > MAX_UNSIGNED_SHORT) { - throw new IllegalArgumentException("Prefetch count must be between 0 and " + MAX_UNSIGNED_SHORT); + int unsignedShortPrefetchCount = ConnectionFactory.ensureUnsignedShort(prefetchCount); + if (unsignedShortPrefetchCount != prefetchCount) { + LOGGER.warn("Prefetch count must be between 0 and {}, value has been set to {} instead of {}", + MAX_UNSIGNED_SHORT, unsignedShortPrefetchCount, prefetchCount); } - exnWrappingRpc(new Basic.Qos(prefetchSize, prefetchCount, global)); + exnWrappingRpc(new Basic.Qos(prefetchSize, unsignedShortPrefetchCount, global)); } /** Public API - {@inheritDoc} */ diff --git a/src/test/java/com/rabbitmq/client/test/ChannelNTest.java b/src/test/java/com/rabbitmq/client/test/ChannelNTest.java index 194f086ef8..c955c28071 100644 --- a/src/test/java/com/rabbitmq/client/test/ChannelNTest.java +++ b/src/test/java/com/rabbitmq/client/test/ChannelNTest.java @@ -25,8 +25,10 @@ import java.io.IOException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; public class ChannelNTest { @@ -64,23 +66,43 @@ public void callingBasicCancelForUnknownConsumerThrowsException() throws Excepti @Test public void qosShouldBeUnsignedShort() { AMQConnection connection = Mockito.mock(AMQConnection.class); - ChannelN channel = new ChannelN(connection, 1, consumerWorkService); + AtomicReference qosMethod = new AtomicReference<>(); + ChannelN channel = new ChannelN(connection, 1, consumerWorkService) { + @Override + public AMQCommand exnWrappingRpc(Method m) { + qosMethod.set((com.rabbitmq.client.AMQP.Basic.Qos) m); + return null; + } + }; class TestConfig { int value; Consumer call; + int expected; - public TestConfig(int value, Consumer call) { + public TestConfig(int value, Consumer call, int expected) { this.value = value; this.call = call; + this.expected = expected; } } Consumer qos = value -> channel.basicQos(value); Consumer qosGlobal = value -> channel.basicQos(value, true); Consumer qosPrefetchSize = value -> channel.basicQos(10, value, true); Stream.of( - new TestConfig(-1, qos), new TestConfig(65536, qos) - ).flatMap(config -> Stream.of(config, new TestConfig(config.value, qosGlobal), new TestConfig(config.value, qosPrefetchSize))) - .forEach(config -> assertThatThrownBy(() -> config.call.apply(config.value)).isInstanceOf(IllegalArgumentException.class)); + new TestConfig(-1, qos, 0), new TestConfig(65536, qos, 65535), + new TestConfig(10, qos, 10), new TestConfig(0, qos, 0) + ).flatMap(config -> Stream.of(config, new TestConfig(config.value, qosGlobal, config.expected), new TestConfig(config.value, qosPrefetchSize, config.expected))) + .forEach(config -> { + try { + assertThat(qosMethod.get()).isNull(); + config.call.apply(config.value); + assertThat(qosMethod.get()).isNotNull(); + assertThat(qosMethod.get().getPrefetchCount()).isEqualTo(config.expected); + qosMethod.set(null); + } catch (Exception e) { + e.printStackTrace(); + } + }); } interface Consumer { diff --git a/src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java b/src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java index ff1e7e4b80..1846078ff2 100644 --- a/src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java +++ b/src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java @@ -27,10 +27,10 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.*; public class ConnectionFactoryTest { @@ -164,33 +164,34 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() { public void heartbeatAndChannelMaxMustBeUnsignedShorts() { class TestConfig { int value; - Consumer call; - boolean expectException; + Supplier getCall; + Consumer setCall; + int expected; - public TestConfig(int value, Consumer call, boolean expectException) { + public TestConfig(int value, Supplier getCall, Consumer setCall, int expected) { this.value = value; - this.call = call; - this.expectException = expectException; + this.getCall = getCall; + this.setCall = setCall; + this.expected = expected; } } ConnectionFactory cf = new ConnectionFactory(); + Supplier getHeartbeart = () -> cf.getRequestedHeartbeat(); Consumer setHeartbeat = cf::setRequestedHeartbeat; + Supplier getChannelMax = () -> cf.getRequestedChannelMax(); Consumer setChannelMax = cf::setRequestedChannelMax; Stream.of( - new TestConfig(0, setHeartbeat, false), - new TestConfig(10, setHeartbeat, false), - new TestConfig(65535, setHeartbeat, false), - new TestConfig(-1, setHeartbeat, true), - new TestConfig(65536, setHeartbeat, true)) - .flatMap(config -> Stream.of(config, new TestConfig(config.value, setChannelMax, config.expectException))) + new TestConfig(0, getHeartbeart, setHeartbeat, 0), + new TestConfig(10, getHeartbeart, setHeartbeat, 10), + new TestConfig(65535, getHeartbeart, setHeartbeat, 65535), + new TestConfig(-1, getHeartbeart, setHeartbeat, 0), + new TestConfig(65536, getHeartbeart, setHeartbeat, 65535)) + .flatMap(config -> Stream.of(config, new TestConfig(config.value, getChannelMax, setChannelMax, config.expected))) .forEach(config -> { - if (config.expectException) { - assertThatThrownBy(() -> config.call.accept(config.value)).isInstanceOf(IllegalArgumentException.class); - } else { - config.call.accept(config.value); - } + config.setCall.accept(config.value); + assertThat(config.getCall.get()).isEqualTo(config.expected); }); }