Skip to content

Commit

Permalink
Expose ShardingInfo
Browse files Browse the repository at this point in the history
Adds ShardingInfo field to DefaultNode and getShardingInfo() to Node
interface. Adds NodeShardingInfo interface as public API.

With this change, having a token and having determined the Node it belongs to,
users can further use NodeShardingInfo to determine the shard it belongs to.

Fixes #232.
  • Loading branch information
Bouncheck authored and Piotr Grabowski committed Jan 15, 2024
1 parent 31a379b commit e561a2d
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,13 @@ public interface Node {
*/
@Nullable
UUID getSchemaVersion();

/**
* Node's sharding information.
*
* <p>May be null if the node is not a Scylla node or the connection pool to the node was never
* created.
*/
@Nullable
NodeShardingInfo getShardingInfo();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.datastax.oss.driver.api.core.metadata;

import com.datastax.oss.driver.api.core.metadata.token.Token;

/** Holds sharding information for a particular Node. */
public interface NodeShardingInfo {

public int getShardsCount();

/**
* Returns a shardId for a given Token.
*
* <p>Accepts all types of Tokens but if the Token is not an instance of {@link
* com.datastax.oss.driver.internal.core.metadata.token.TokenLong64} then the return value could
* be not meaningful (e.g. random shard). This method does not verify if the given Token belongs
* to the Node.
*/
public int shardId(Token t);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.datastax.oss.driver.internal.core.context.InternalDriverContext;
import com.datastax.oss.driver.internal.core.metrics.NodeMetricUpdater;
import com.datastax.oss.driver.internal.core.metrics.NoopNodeMetricUpdater;
import com.datastax.oss.driver.internal.core.protocol.ShardingInfo;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.io.Serializable;
Expand All @@ -35,8 +36,8 @@
import net.jcip.annotations.ThreadSafe;

/**
* Implementation note: all the mutable state in this class is read concurrently, but only mutated
* from {@link MetadataManager}'s admin thread.
* Implementation note: (almost) all the mutable state in this class is read concurrently, but only
* mutated from {@link MetadataManager}'s admin thread. Node's ShardingInfo is an exception.
*/
@ThreadSafe
public class DefaultNode implements Node, Serializable {
Expand Down Expand Up @@ -67,6 +68,10 @@ public class DefaultNode implements Node, Serializable {

volatile NodeDistance distance;

// Initially null. A copy of ShardingInfo. Updated with values by DriverChannel during pool
// initialization.
private volatile ShardingInfo shardingInfo;

public DefaultNode(EndPoint endPoint, InternalDriverContext context) {
this.endPoint = endPoint;
this.state = NodeState.UNKNOWN;
Expand All @@ -77,6 +82,7 @@ public DefaultNode(EndPoint endPoint, InternalDriverContext context) {
// problem because the node updater only needs the connect address to initialize.
this.metricUpdater = context.getMetricsFactory().newNodeUpdater(this);
this.upSinceMillis = -1;
this.shardingInfo = null;
}

@NonNull
Expand Down Expand Up @@ -193,4 +199,14 @@ public String toString() {
public Set<String> getRawTokens() {
return rawTokens;
}

@Nullable
@Override
public ShardingInfo getShardingInfo() {
return shardingInfo;
}

public void setShardingInfo(ShardingInfo shardingInfo) {
this.shardingInfo = shardingInfo;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ private void addChannel(DriverChannel c) {

private void initialize(DriverChannel c) {
shardingInfo = c.getShardingInfo();
((DefaultNode) node).setShardingInfo(shardingInfo);
int wanted = getConfiguredSize(distance);
int shardsCount = shardingInfo == null ? 1 : shardingInfo.getShardsCount();
wantedCount = wanted / shardsCount + (wanted % shardsCount > 0 ? 1 : 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
*/
package com.datastax.oss.driver.internal.core.protocol;

import com.datastax.oss.driver.api.core.metadata.NodeShardingInfo;
import com.datastax.oss.driver.api.core.metadata.token.Token;
import com.datastax.oss.driver.internal.core.metadata.token.TokenLong64;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;

/** Keeps the information the driver maintains on data layout of a given node. */
public class ShardingInfo {
public class ShardingInfo implements NodeShardingInfo {
private static final String SCYLLA_SHARD_PARAM_KEY = "SCYLLA_SHARD";
private static final String SCYLLA_NR_SHARDS_PARAM_KEY = "SCYLLA_NR_SHARDS";
private static final String SCYLLA_PARTITIONER = "SCYLLA_PARTITIONER";
Expand All @@ -48,6 +49,7 @@ private ShardingInfo(
this.shardingIgnoreMSB = shardingIgnoreMSB;
}

@Override
public int getShardsCount() {
return shardsCount;
}
Expand All @@ -60,6 +62,7 @@ public String getShardingAlgorithm() {
return shardingAlgorithm;
}

@Override
public int shardId(Token t) {
if (!(t instanceof TokenLong64)) {
return ThreadLocalRandom.current().nextInt(shardsCount);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package com.datastax.oss.driver.examples.basic;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.DefaultProtocolVersion;
import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.cql.TraceEvent;
import com.datastax.oss.driver.api.core.metadata.Metadata;
import com.datastax.oss.driver.api.core.metadata.Node;
import com.datastax.oss.driver.api.core.metadata.TokenMap;
import com.datastax.oss.driver.api.core.metadata.token.Token;
import com.datastax.oss.driver.api.core.type.codec.TypeCodecs;
import java.nio.ByteBuffer;
import java.util.Set;

/**
* Demonstrates usage of TokenMap and NodeShardingInfo Needs a Scylla cluster to be running locally
* or adjustment of session builder.
*/
public class TokenMapAndShardIdLookup {

private static String CREATE_KEYSPACE =
"CREATE KEYSPACE IF NOT EXISTS tokenmap_example_ks "
+ "WITH replication = {"
+ "'class': 'SimpleStrategy', "
+ "'replication_factor': 1"
+ "}";

private static String CREATE_TABLE =
""
+ "CREATE TABLE IF NOT EXISTS tokenmap_example_ks.example_tab ("
+ "my_column bigint,"
+ "PRIMARY KEY (my_column)"
+ ")";

private static String INSERT_COLUMN =
"INSERT INTO tokenmap_example_ks.example_tab (my_column) VALUES (2)";

private static String SELECT_COLUMN =
"SELECT * FROM tokenmap_example_ks.example_tab WHERE my_column = 2";

private static ByteBuffer PARTITION_KEY = TypeCodecs.BIGINT.encode(2L, DefaultProtocolVersion.V3);

public static void main(String[] args) {

try (CqlSession session = CqlSession.builder().build()) {

System.out.printf("Connected session: %s%n", session.getName());

session.execute(CREATE_KEYSPACE);
session.execute(CREATE_TABLE);
session.execute(INSERT_COLUMN);

Metadata metadata = session.refreshSchema();

System.out.println("Prepared example data");

TokenMap tokenMap = metadata.getTokenMap().get();

Set<Node> nodes =
tokenMap.getReplicas(CqlIdentifier.fromCql("tokenmap_example_ks"), PARTITION_KEY);
System.out.println("Replica set size: " + nodes.size());

Token token = tokenMap.newToken(PARTITION_KEY);
assert nodes.size() > 0;
Node node = nodes.iterator().next();

assert node.getShardingInfo() != null;
int shardId = node.getShardingInfo().shardId(token);

System.out.println(
"Hardcoded partition key should belong to shard number "
+ shardId
+ " (on Node: "
+ node
+ ")");

System.out.println("You can compare it with SELECT query trace:");
// If there is only 1 node, then the SELECT has to hit the one we did shardId calculation for.
SimpleStatement statement = SimpleStatement.builder(SELECT_COLUMN).setTracing(true).build();
ResultSet rs = session.execute(statement);

for (TraceEvent event : rs.getExecutionInfo().getQueryTrace().getEvents()) {
System.out.println(event);
}
}
}
}

0 comments on commit e561a2d

Please sign in to comment.