diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java index 7886442a70d..aeda4a02c7b 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/pool/AdvancedShardAwarenessIT.java @@ -10,36 +10,43 @@ import com.datastax.oss.driver.api.core.CqlSessionBuilder; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfigLoader; -import com.datastax.oss.driver.api.core.session.Session; +import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.testinfra.ScyllaOnly; import com.datastax.oss.driver.api.testinfra.ccm.CustomCcmRule; import com.datastax.oss.driver.api.testinfra.session.SessionUtils; +import com.datastax.oss.driver.categories.IsolatedTests; import com.datastax.oss.driver.internal.core.pool.ChannelPool; +import com.datastax.oss.driver.internal.core.session.DefaultSession; import com.datastax.oss.driver.internal.core.util.concurrent.CompletableFutures; import com.datastax.oss.driver.internal.core.util.concurrent.Reconnection; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.util.concurrent.Uninterruptibles; import com.tngtech.java.junit.dataprovider.DataProvider; import com.tngtech.java.junit.dataprovider.DataProviderRunner; import com.tngtech.java.junit.dataprovider.UseDataProvider; import java.net.InetSocketAddress; import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; import java.util.regex.Pattern; +import org.awaitility.Awaitility; import org.junit.After; import org.junit.Before; import org.junit.ClassRule; import org.junit.Test; +import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.slf4j.LoggerFactory; @ScyllaOnly(description = "Advanced shard awareness relies on ScyllaDB's shard aware port") @RunWith(DataProviderRunner.class) +@Category(IsolatedTests.class) public class AdvancedShardAwarenessIT { @ClassRule @@ -55,9 +62,8 @@ public class AdvancedShardAwarenessIT { Level originalLevelReconnection; private final Pattern shardMismatchPattern = Pattern.compile(".*r configuration of shard aware port.*"); - private final Pattern reconnectionPattern = + private final Pattern generalReconnectionPattern = Pattern.compile(".*Scheduling next reconnection in.*"); - Set forbiddenOccurences = ImmutableSet.of(shardMismatchPattern, reconnectionPattern); @DataProvider public static Object[][] reuseAddressOption() { @@ -92,15 +98,27 @@ public void stopCapturingLogs() { @Test @UseDataProvider("reuseAddressOption") public void should_initialize_all_channels(boolean reuseAddress) { + int expectedChannelsPerNode = 6; // Divisible by smp + String node1 = CCM_RULE.getCcmBridge().getNodeIpAddress(1); + String node2 = CCM_RULE.getCcmBridge().getNodeIpAddress(2); + Pattern reconnectionPattern1 = + Pattern.compile(".*" + Pattern.quote(node1) + ".*Scheduling next reconnection in.*"); + Pattern reconnectionPattern2 = + Pattern.compile(".*" + Pattern.quote(node2) + ".*Scheduling next reconnection in.*"); + Set forbiddenOccurences = + ImmutableSet.of(shardMismatchPattern, reconnectionPattern1, reconnectionPattern2); Map expectedOccurences = ImmutableMap.of( - Pattern.compile(".*\\.2:19042.*Reconnection attempt complete, 6/6 channels.*"), 1, - Pattern.compile(".*\\.1:19042.*Reconnection attempt complete, 6/6 channels.*"), 1, - Pattern.compile(".*Reconnection attempt complete.*"), 2, - Pattern.compile(".*\\.1:19042.*New channel added \\[.*"), 5, - Pattern.compile(".*\\.2:19042.*New channel added \\[.*"), 5, - Pattern.compile(".*\\.1:19042\\] Trying to create 5 missing channels.*"), 1, - Pattern.compile(".*\\.2:19042\\] Trying to create 5 missing channels.*"), 1); + Pattern.compile( + ".*" + + Pattern.quote(node1) + + ":19042.*Reconnection attempt complete, 6/6 channels.*"), + 1, + Pattern.compile( + ".*" + + Pattern.quote(node2) + + ":19042.*Reconnection attempt complete, 6/6 channels.*"), + 1); DriverConfigLoader loader = SessionUtils.configLoaderBuilder() .withBoolean(DefaultDriverOption.SOCKET_REUSE_ADDRESS, reuseAddress) @@ -109,48 +127,60 @@ public void should_initialize_all_channels(boolean reuseAddress) { .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 60000) // Due to rounding up the connections per shard this will result in 6 connections per // node - .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 4) + .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, expectedChannelsPerNode) .build(); - try (Session session = + try (CqlSession session = CqlSession.builder() .addContactPoint( new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 19042)) .withConfigLoader(loader) .build()) { - Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS); + List allSessions = Collections.singletonList(session); + Awaitility.await() + .atMost(5, TimeUnit.SECONDS) + .pollInterval(500, TimeUnit.MILLISECONDS) + .until(() -> areAllPoolsFullyInitialized(allSessions, expectedChannelsPerNode)); + List logsCopy = ImmutableList.copyOf(appender.list); expectedOccurences.forEach( - (pattern, times) -> assertMatchesExactly(pattern, times, appender.list)); - forbiddenOccurences.forEach(pattern -> assertNoLogMatches(pattern, appender.list)); + (pattern, times) -> assertMatchesExactly(pattern, times, logsCopy)); + forbiddenOccurences.forEach(pattern -> assertNoLogMatches(pattern, logsCopy)); } } @Test public void should_see_mismatched_shard() { + int expectedChannelsPerNode = 66; // Divisible by smp DriverConfigLoader loader = SessionUtils.configLoaderBuilder() .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true) .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_LOW, 10000) .withInt(DefaultDriverOption.ADVANCED_SHARD_AWARENESS_PORT_HIGH, 60000) - .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 64) + .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 66) .build(); - try (Session session = + try (CqlSession session = CqlSession.builder() .addContactPoint( new InetSocketAddress(CCM_RULE.getCcmBridge().getNodeIpAddress(1), 9042)) .withConfigLoader(loader) .build()) { - Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS); - assertMatchesAtLeast(shardMismatchPattern, 5, appender.list); + List allSessions = Collections.singletonList(session); + Awaitility.await() + .atMost(20, TimeUnit.SECONDS) + .pollInterval(500, TimeUnit.MILLISECONDS) + .until(() -> areAllPoolsFullyInitialized(allSessions, expectedChannelsPerNode)); + List logsCopy = ImmutableList.copyOf(appender.list); + assertMatchesAtLeast(shardMismatchPattern, 5, logsCopy); } } // There is no need to run this as a test, but it serves as a comparison @SuppressWarnings("unused") public void should_struggle_to_fill_pools() { + int expectedChannelsPerNode = 66; // Divisible by smp DriverConfigLoader loader = SessionUtils.configLoaderBuilder() .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, false) - .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 64) + .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 66) .withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofMillis(200)) .withDuration(DefaultDriverOption.RECONNECTION_MAX_DELAY, Duration.ofMillis(4000)) .build(); @@ -167,18 +197,24 @@ public void should_struggle_to_fill_pools() { CqlSession session2 = CompletableFutures.getUninterruptibly(stage2); CqlSession session3 = CompletableFutures.getUninterruptibly(stage3); CqlSession session4 = CompletableFutures.getUninterruptibly(stage4); ) { - Uninterruptibles.sleepUninterruptibly(20, TimeUnit.SECONDS); - assertNoLogMatches(shardMismatchPattern, appender.list); - assertMatchesAtLeast(reconnectionPattern, 8, appender.list); + List allSessions = Arrays.asList(session1, session2, session3, session4); + Awaitility.await() + .atMost(20, TimeUnit.SECONDS) + .pollInterval(500, TimeUnit.MILLISECONDS) + .until(() -> areAllPoolsFullyInitialized(allSessions, expectedChannelsPerNode)); + List logsCopy = ImmutableList.copyOf(appender.list); + assertNoLogMatches(shardMismatchPattern, logsCopy); + assertMatchesAtLeast(generalReconnectionPattern, 8, logsCopy); } } @Test public void should_not_struggle_to_fill_pools() { + int expectedChannelsPerNode = 66; DriverConfigLoader loader = SessionUtils.configLoaderBuilder() .withBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED, true) - .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, 66) + .withInt(DefaultDriverOption.CONNECTION_POOL_LOCAL_SIZE, expectedChannelsPerNode) .withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofMillis(10)) .withDuration(DefaultDriverOption.RECONNECTION_MAX_DELAY, Duration.ofMillis(20)) .build(); @@ -196,25 +232,58 @@ public void should_not_struggle_to_fill_pools() { CqlSession session2 = CompletableFutures.getUninterruptibly(stage2); CqlSession session3 = CompletableFutures.getUninterruptibly(stage3); CqlSession session4 = CompletableFutures.getUninterruptibly(stage4); ) { - Uninterruptibles.sleepUninterruptibly(8, TimeUnit.SECONDS); + List allSessions = Arrays.asList(session1, session2, session3, session4); + Awaitility.await() + .atMost(20, TimeUnit.SECONDS) + .pollInterval(500, TimeUnit.MILLISECONDS) + .until(() -> areAllPoolsFullyInitialized(allSessions, expectedChannelsPerNode)); int tolerance = 2; // Sometimes socket ends up already in use + String node1 = CCM_RULE.getCcmBridge().getNodeIpAddress(1); + String node2 = CCM_RULE.getCcmBridge().getNodeIpAddress(2); + Pattern reconnectionPattern1 = + Pattern.compile(".*" + Pattern.quote(node1) + ".*Scheduling next reconnection in.*"); + Pattern reconnectionPattern2 = + Pattern.compile(".*" + Pattern.quote(node2) + ".*Scheduling next reconnection in.*"); Map expectedOccurences = ImmutableMap.of( - Pattern.compile(".*\\.2:19042.*Reconnection attempt complete, 66/66 channels.*"), + Pattern.compile( + ".*" + + Pattern.quote(node1) + + ":19042.*Reconnection attempt complete, 66/66 channels.*"), 1 * sessions, - Pattern.compile(".*\\.1:19042.*Reconnection attempt complete, 66/66 channels.*"), - 1 * sessions, - Pattern.compile(".*Reconnection attempt complete.*"), 2 * sessions, - Pattern.compile(".*.1:19042.*New channel added \\[.*"), 65 * sessions - tolerance, - Pattern.compile(".*.2:19042.*New channel added \\[.*"), 65 * sessions - tolerance, - Pattern.compile(".*.1:19042\\] Trying to create 65 missing channels.*"), 1 * sessions, - Pattern.compile(".*.2:19042\\] Trying to create 65 missing channels.*"), + Pattern.compile( + ".*" + + Pattern.quote(node2) + + ":19042.*Reconnection attempt complete, 66/66 channels.*"), 1 * sessions); + List logsCopy = ImmutableList.copyOf(appender.list); expectedOccurences.forEach( - (pattern, times) -> assertMatchesAtLeast(pattern, times, appender.list)); - assertNoLogMatches(shardMismatchPattern, appender.list); - assertMatchesAtMost(reconnectionPattern, tolerance, appender.list); + (pattern, times) -> assertMatchesAtLeast(pattern, times, logsCopy)); + assertNoLogMatches(shardMismatchPattern, logsCopy); + assertMatchesAtMost(reconnectionPattern1, tolerance, logsCopy); + assertMatchesAtMost(reconnectionPattern2, tolerance, logsCopy); + } + } + + private boolean areAllPoolsFullyInitialized( + List sessions, int expectedChannelsPerNode) { + for (CqlSession session : sessions) { + DefaultSession defaultSession = (DefaultSession) session; + Map pools = defaultSession.getPools(); + if (pools == null || pools.isEmpty()) { + return false; + } + + for (ChannelPool pool : pools.values()) { + if (pool == null) { + return false; + } + if (pool.size() < expectedChannelsPerNode) { + return false; + } + } } + return true; } private void assertNoLogMatches(Pattern pattern, List logs) {