Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions src/main/java/com/rabbitmq/client/ConnectionFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import com.rabbitmq.client.impl.recovery.AutorecoveringConnection;
import com.rabbitmq.client.impl.recovery.RetryHandler;
import com.rabbitmq.client.impl.recovery.TopologyRecoveryFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
Expand All @@ -47,6 +49,8 @@
*/
public class ConnectionFactory implements Cloneable {

private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionFactory.class);

private static final int MAX_UNSIGNED_SHORT = 65535;

/** Default user name */
Expand Down Expand Up @@ -393,10 +397,11 @@ public int getRequestedChannelMax() {
* @param requestedChannelMax initially requested maximum channel number; zero for unlimited
*/
public void setRequestedChannelMax(int requestedChannelMax) {
if (requestedChannelMax < 0 || requestedChannelMax > MAX_UNSIGNED_SHORT) {
throw new IllegalArgumentException("Requested channel max must be between 0 and " + MAX_UNSIGNED_SHORT);
this.requestedChannelMax = ensureUnsignedShort(requestedChannelMax);
if (this.requestedChannelMax != requestedChannelMax) {
LOGGER.warn("Requested channel max must be between 0 and {}, value has been set to {} instead of {}",
MAX_UNSIGNED_SHORT, this.requestedChannelMax, requestedChannelMax);
}
this.requestedChannelMax = requestedChannelMax;
}

/**
Expand Down Expand Up @@ -492,10 +497,11 @@ public int getShutdownTimeout() {
* @see <a href="https://rabbitmq.com/heartbeats.html">RabbitMQ Heartbeats Guide</a>
*/
public void setRequestedHeartbeat(int requestedHeartbeat) {
if (requestedHeartbeat < 0 || requestedHeartbeat > MAX_UNSIGNED_SHORT) {
throw new IllegalArgumentException("Requested heartbeat must be between 0 and " + MAX_UNSIGNED_SHORT);
this.requestedHeartbeat = ensureUnsignedShort(requestedHeartbeat);
if (this.requestedHeartbeat != requestedHeartbeat) {
LOGGER.warn("Requested heartbeat must be between 0 and {}, value has been set to {} instead of {}",
MAX_UNSIGNED_SHORT, this.requestedHeartbeat, requestedHeartbeat);
}
this.requestedHeartbeat = requestedHeartbeat;
}

/**
Expand Down Expand Up @@ -1574,4 +1580,14 @@ public void setTopologyRecoveryRetryHandler(RetryHandler topologyRecoveryRetryHa
public void setTrafficListener(TrafficListener trafficListener) {
this.trafficListener = trafficListener;
}

public static int ensureUnsignedShort(int value) {
if (value < 0) {
return 0;
} else if (value > MAX_UNSIGNED_SHORT) {
return MAX_UNSIGNED_SHORT;
} else {
return value;
}
}
}
18 changes: 12 additions & 6 deletions src/main/java/com/rabbitmq/client/impl/AMQConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,15 @@ public void start()
}

try {
int channelMax =
int negotiatedChannelMax =
negotiateChannelMax(this.requestedChannelMax,
connTune.getChannelMax());

if (!checkUnsignedShort(channelMax)) {
throw new IllegalArgumentException("Negotiated channel max must be between 0 and " + MAX_UNSIGNED_SHORT + ": " + channelMax);
int channelMax = ConnectionFactory.ensureUnsignedShort(negotiatedChannelMax);

if (channelMax != negotiatedChannelMax) {
LOGGER.warn("Channel max must be between 0 and {}, value has been set to {} instead of {}",
MAX_UNSIGNED_SHORT, channelMax, negotiatedChannelMax);
}

_channelManager = instantiateChannelManager(channelMax, threadFactory);
Expand All @@ -415,12 +418,15 @@ public void start()
connTune.getFrameMax());
this._frameMax = frameMax;

int heartbeat =
int negotiatedHeartbeat =
negotiatedMaxValue(this.requestedHeartbeat,
connTune.getHeartbeat());

if (!checkUnsignedShort(heartbeat)) {
throw new IllegalArgumentException("Negotiated heartbeat must be between 0 and " + MAX_UNSIGNED_SHORT + ": " + heartbeat);
int heartbeat = ConnectionFactory.ensureUnsignedShort(negotiatedHeartbeat);

if (heartbeat != negotiatedHeartbeat) {
LOGGER.warn("Heartbeat must be between 0 and {}, value has been set to {} instead of {}",
MAX_UNSIGNED_SHORT, heartbeat, negotiatedHeartbeat);
}

setHeartbeat(heartbeat);
Expand Down
8 changes: 5 additions & 3 deletions src/main/java/com/rabbitmq/client/impl/ChannelN.java
Original file line number Diff line number Diff line change
Expand Up @@ -642,10 +642,12 @@ public AMQCommand transformReply(AMQCommand command) {
public void basicQos(int prefetchSize, int prefetchCount, boolean global)
throws IOException
{
if (prefetchCount < 0 || prefetchCount > MAX_UNSIGNED_SHORT) {
throw new IllegalArgumentException("Prefetch count must be between 0 and " + MAX_UNSIGNED_SHORT);
int unsignedShortPrefetchCount = ConnectionFactory.ensureUnsignedShort(prefetchCount);
if (unsignedShortPrefetchCount != prefetchCount) {
LOGGER.warn("Prefetch count must be between 0 and {}, value has been set to {} instead of {}",
MAX_UNSIGNED_SHORT, unsignedShortPrefetchCount, prefetchCount);
}
exnWrappingRpc(new Basic.Qos(prefetchSize, prefetchCount, global));
exnWrappingRpc(new Basic.Qos(prefetchSize, unsignedShortPrefetchCount, global));
}

/** Public API - {@inheritDoc} */
Expand Down
32 changes: 27 additions & 5 deletions src/test/java/com/rabbitmq/client/test/ChannelNTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class ChannelNTest {
Expand Down Expand Up @@ -64,23 +66,43 @@ public void callingBasicCancelForUnknownConsumerThrowsException() throws Excepti
@Test
public void qosShouldBeUnsignedShort() {
AMQConnection connection = Mockito.mock(AMQConnection.class);
ChannelN channel = new ChannelN(connection, 1, consumerWorkService);
AtomicReference<com.rabbitmq.client.AMQP.Basic.Qos> qosMethod = new AtomicReference<>();
ChannelN channel = new ChannelN(connection, 1, consumerWorkService) {
@Override
public AMQCommand exnWrappingRpc(Method m) {
qosMethod.set((com.rabbitmq.client.AMQP.Basic.Qos) m);
return null;
}
};
class TestConfig {
int value;
Consumer call;
int expected;

public TestConfig(int value, Consumer call) {
public TestConfig(int value, Consumer call, int expected) {
this.value = value;
this.call = call;
this.expected = expected;
}
}
Consumer qos = value -> channel.basicQos(value);
Consumer qosGlobal = value -> channel.basicQos(value, true);
Consumer qosPrefetchSize = value -> channel.basicQos(10, value, true);
Stream.of(
new TestConfig(-1, qos), new TestConfig(65536, qos)
).flatMap(config -> Stream.of(config, new TestConfig(config.value, qosGlobal), new TestConfig(config.value, qosPrefetchSize)))
.forEach(config -> assertThatThrownBy(() -> config.call.apply(config.value)).isInstanceOf(IllegalArgumentException.class));
new TestConfig(-1, qos, 0), new TestConfig(65536, qos, 65535),
new TestConfig(10, qos, 10), new TestConfig(0, qos, 0)
).flatMap(config -> Stream.of(config, new TestConfig(config.value, qosGlobal, config.expected), new TestConfig(config.value, qosPrefetchSize, config.expected)))
.forEach(config -> {
try {
assertThat(qosMethod.get()).isNull();
config.call.apply(config.value);
assertThat(qosMethod.get()).isNotNull();
assertThat(qosMethod.get().getPrefetchCount()).isEqualTo(config.expected);
qosMethod.set(null);
} catch (Exception e) {
e.printStackTrace();
}
});
}

interface Consumer {
Expand Down
35 changes: 18 additions & 17 deletions src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;

public class ConnectionFactoryTest {
Expand Down Expand Up @@ -164,33 +164,34 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() {
public void heartbeatAndChannelMaxMustBeUnsignedShorts() {
class TestConfig {
int value;
Consumer<Integer> call;
boolean expectException;
Supplier<Integer> getCall;
Consumer<Integer> setCall;
int expected;

public TestConfig(int value, Consumer<Integer> call, boolean expectException) {
public TestConfig(int value, Supplier<Integer> getCall, Consumer<Integer> setCall, int expected) {
this.value = value;
this.call = call;
this.expectException = expectException;
this.getCall = getCall;
this.setCall = setCall;
this.expected = expected;
}
}

ConnectionFactory cf = new ConnectionFactory();
Supplier<Integer> getHeartbeart = () -> cf.getRequestedHeartbeat();
Consumer<Integer> setHeartbeat = cf::setRequestedHeartbeat;
Supplier<Integer> getChannelMax = () -> cf.getRequestedChannelMax();
Consumer<Integer> setChannelMax = cf::setRequestedChannelMax;

Stream.of(
new TestConfig(0, setHeartbeat, false),
new TestConfig(10, setHeartbeat, false),
new TestConfig(65535, setHeartbeat, false),
new TestConfig(-1, setHeartbeat, true),
new TestConfig(65536, setHeartbeat, true))
.flatMap(config -> Stream.of(config, new TestConfig(config.value, setChannelMax, config.expectException)))
new TestConfig(0, getHeartbeart, setHeartbeat, 0),
new TestConfig(10, getHeartbeart, setHeartbeat, 10),
new TestConfig(65535, getHeartbeart, setHeartbeat, 65535),
new TestConfig(-1, getHeartbeart, setHeartbeat, 0),
new TestConfig(65536, getHeartbeart, setHeartbeat, 65535))
.flatMap(config -> Stream.of(config, new TestConfig(config.value, getChannelMax, setChannelMax, config.expected)))
.forEach(config -> {
if (config.expectException) {
assertThatThrownBy(() -> config.call.accept(config.value)).isInstanceOf(IllegalArgumentException.class);
} else {
config.call.accept(config.value);
}
config.setCall.accept(config.value);
assertThat(config.getCall.get()).isEqualTo(config.expected);
});

}
Expand Down