Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.datastax.driver.core.Frame.Header;
import com.datastax.driver.core.Requests.QueryFlag;
import com.datastax.driver.core.exceptions.UnsupportedFeatureException;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import java.nio.ByteBuffer;
import java.util.ArrayList;
Expand Down Expand Up @@ -267,13 +268,10 @@ public BatchStatement setSerialConsistencyLevel(ConsistencyLevel serialConsisten

@Override
public ByteBuffer getRoutingKey(ProtocolVersion protocolVersion, CodecRegistry codecRegistry) {
for (Statement statement : statements) {
if (statement instanceof StatementWrapper)
statement = ((StatementWrapper) statement).getWrappedStatement();
ByteBuffer rk = statement.getRoutingKey(protocolVersion, codecRegistry);
if (rk != null) return rk;
}
return null;
Statement routingStatement = getRoutingStatement(protocolVersion, codecRegistry);
return routingStatement == null
? null
: routingStatement.getRoutingKey(protocolVersion, codecRegistry);
}

@Override
Expand All @@ -298,6 +296,22 @@ void ensureAllSet() {
if (statement instanceof BoundStatement) ((BoundStatement) statement).ensureAllSet();
}

/**
* Returns the first statement in this batch that provides a routing key for the given protocol
* version and codec registry.
*/
@Beta
public Statement getRoutingStatement(
ProtocolVersion protocolVersion, CodecRegistry codecRegistry) {
for (Statement statement : statements) {
if (statement instanceof StatementWrapper)
statement = ((StatementWrapper) statement).getWrappedStatement();
ByteBuffer rk = statement.getRoutingKey(protocolVersion, codecRegistry);
if (rk != null) return statement;
}
return null;
}

static class IdAndValues {

public final List<Object> ids;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
*/
package com.datastax.driver.core.policies;

import com.datastax.driver.core.BatchStatement;
import com.datastax.driver.core.BoundStatement;
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.CodecRegistry;
Expand Down Expand Up @@ -429,16 +430,7 @@ public Iterator<Host> newQueryPlan(final String loggedKeyspace, final Statement
if (partitionKey == null || keyspace == null)
return childPolicy.newQueryPlan(keyspace, statement);

String tableName = null;
ColumnDefinitions defs = null;
if (statement instanceof BoundStatement) {
defs = ((BoundStatement) statement).preparedStatement().getVariables();
} else if (statement instanceof PreparedStatement) {
defs = ((PreparedStatement) statement).getVariables();
}
if (defs != null && defs.size() > 0) {
tableName = defs.getTable(0);
}
String tableName = getRoutingTable(statement);

final List<Host> replicas =
clusterMetadata.getReplicasList(
Expand All @@ -453,6 +445,28 @@ public Iterator<Host> newQueryPlan(final String loggedKeyspace, final Statement
}
}

private String getRoutingTable(Statement statement) {
ColumnDefinitions defs = getRoutingVariables(statement);
return (defs == null || defs.size() == 0) ? null : defs.getTable(0);
}

private ColumnDefinitions getRoutingVariables(Statement statement) {
Statement target = statement;
if (statement instanceof BatchStatement) {
target = ((BatchStatement) statement).getRoutingStatement(protocolVersion, codecRegistry);
if (target == null) {
return null;
}
}

if (target instanceof BoundStatement) {
return ((BoundStatement) target).preparedStatement().getVariables();
} else if (target instanceof PreparedStatement) {
return ((PreparedStatement) target).getVariables();
}
return null;
}

private QueryOptions.RequestRoutingMethod getRequestRouting(Statement statement) {
if (!statement.isLWT() || defaultLwtRequestRoutingMethod == null) {
return QueryOptions.RequestRoutingMethod.REGULAR;
Expand Down
23 changes: 23 additions & 0 deletions driver-core/src/test/java/com/datastax/driver/core/TabletsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,29 @@ public void should_receive_each_tablet_exactly_once() {
}
}

@Test(groups = "short")
public void batch_statement_should_deliver_tablet_info_and_route_properly() {
prepareCluster();
Session session = newSession();
try {
session
.getCluster()
.getMetadata()
.getTabletMap()
.removeTableMappings(KEYSPACE_NAME.toLowerCase());

PreparedStatement preparedStatement = session.prepare(STMT_INSERT);
Assert.assertTrue(
executeOnAllHostsAndReturnIfResultHasTabletsInfo(session, preparedStatement.bind(2, 2)));
Assert.assertTrue(waitSessionLearnedTabletInfo(session));

BatchStatement routedBatch = new BatchStatement().add(preparedStatement.bind(2, 2));
Assert.assertTrue(checkIfRoutedProperly(session, routedBatch));
} finally {
session.close();
}
}

private static boolean waitSessionLearnedTabletInfo(Session session) {
if (isSessionLearnedTabletInfo(session)) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,16 @@
import static com.datastax.driver.core.policies.TokenAwarePolicy.ReplicaOrdering.TOPOLOGICAL;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.datastax.driver.core.BatchStatement;
import com.datastax.driver.core.BoundStatement;
import com.datastax.driver.core.CCMBridge;
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.CodecRegistry;
import com.datastax.driver.core.ColumnDefinitions;
import com.datastax.driver.core.Configuration;
import com.datastax.driver.core.Host;
import com.datastax.driver.core.HostDistance;
Expand Down Expand Up @@ -155,6 +159,65 @@ public void should_create_random_order() {
assertThat(queryPlan).containsOnlyOnce(host1, host2, host3, host4).endsWith(host4, host3);
}

@Test(groups = "unit")
public void should_use_table_name_from_bound_statement_for_tablet_routing() {
// given
BoundStatement bound = newBoundStatement("tablets_table", routingKey);
when(metadata.getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey))
.thenReturn(Lists.newArrayList(host1, host2));
when(childPolicy.newQueryPlan(KEYSPACE, bound))
.thenReturn(Lists.newArrayList(host4, host3, host2, host1).iterator());

TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, TOPOLOGICAL);
policy.init(cluster, null);

// when
Iterator<Host> queryPlan = policy.newQueryPlan(KEYSPACE, bound);

// then
assertThat(queryPlan).containsExactly(host1, host2, host4, host3);
verify(metadata).getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey);
}

@Test(groups = "unit")
public void should_use_table_name_from_routed_statement_in_batch_for_tablet_routing() {
// given
BoundStatement skippedBound = newBoundStatement("ignored_table", null);
BoundStatement routedBound = newBoundStatement("tablets_table", routingKey);

BatchStatement batch = new BatchStatement().add(skippedBound).add(routedBound);
when(metadata.getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey))
.thenReturn(Lists.newArrayList(host1, host2));
when(childPolicy.newQueryPlan(KEYSPACE, batch))
.thenReturn(Lists.newArrayList(host4, host3, host2, host1).iterator());

TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, TOPOLOGICAL);
policy.init(cluster, null);

// when
Iterator<Host> queryPlan = policy.newQueryPlan(KEYSPACE, batch);

// then
assertThat(queryPlan).containsExactly(host1, host2, host4, host3);
verify(metadata).getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey);
verify(metadata, never())
.getReplicasList(Metadata.quote(KEYSPACE), "ignored_table", null, routingKey);
}

private BoundStatement newBoundStatement(String table, ByteBuffer routingKey) {
BoundStatement bound = mock(BoundStatement.class);
PreparedStatement prepared = mock(PreparedStatement.class);
ColumnDefinitions variables = mock(ColumnDefinitions.class);
when(bound.getKeyspace()).thenReturn(KEYSPACE);
when(bound.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class)))
.thenReturn(routingKey);
when(bound.preparedStatement()).thenReturn(prepared);
when(prepared.getVariables()).thenReturn(variables);
when(variables.size()).thenReturn(1);
when(variables.getTable(0)).thenReturn(table);
return bound;
}

@Test(groups = "unit", dataProvider = "shuffleProvider")
public void should_prioritize_local_replicas_for_lwt(TokenAwarePolicy.ReplicaOrdering ordering) {
// given
Expand Down
Loading