Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 126 additions & 5 deletions src/it/java/io/weaviate/containers/Weaviate.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package io.weaviate.containers;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import org.testcontainers.containers.Network;
import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.lifecycle.Startable;
import org.testcontainers.weaviate.WeaviateContainer;

import io.weaviate.client6.v1.api.Config;
Expand All @@ -20,6 +26,22 @@ public class Weaviate extends WeaviateContainer {
public static String OIDC_ISSUER = "https://auth.wcs.api.weaviate.io/auth/realms/SeMI";

private volatile SharedClient clientInstance;
private final String containerName;

/**
* By default, testcontainer's name is only available after calling
* {@link #start}.
* We need to know each container's name in advance to run a cluster
* of several nodes, in which case we alse set the name manually.
*
* @see Builder#build()
*/
@Override
public String getContainerName() {
return containerName != null
? containerName
: super.getContainerName();
}

/**
* Create a new instance of WeaviateClient connected to this container if none
Expand Down Expand Up @@ -85,17 +107,22 @@ public static Weaviate.Builder custom() {
}

public static class Builder {
private String versionTag;
private String versionTag = VERSION;
private String containerName = "weaviate";
private Set<String> enableModules = new HashSet<>();
private Set<String> adminUsers = new HashSet<>();
private Set<String> viewerUsers = new HashSet<>();
private Map<String, String> environment = new HashMap<>();

public Builder() {
this.versionTag = VERSION;
enableAutoSchema(false);
}

public Builder withContainerName(String containerName) {
this.containerName = containerName;
return this;
}

public Builder withVersion(String version) {
this.versionTag = version;
return this;
Expand Down Expand Up @@ -138,6 +165,7 @@ public Builder withFilesystemBackup(String fsPath) {
environment.put("BACKUP_FILESYSTEM_PATH", fsPath);
return this;
}

public Builder withAdminUsers(String... admins) {
adminUsers.addAll(Arrays.asList(admins));
return this;
Expand Down Expand Up @@ -195,7 +223,7 @@ public Builder withOIDC(String clientId, String issuer, String usernameClaim, St
}

public Weaviate build() {
var c = new Weaviate(DOCKER_IMAGE + ":" + versionTag);
var c = new Weaviate(containerName, DOCKER_IMAGE + ":" + versionTag);

if (!enableModules.isEmpty()) {
c.withEnv("ENABLE_API_BASED_MODULES", Boolean.TRUE.toString());
Expand All @@ -217,13 +245,18 @@ public Weaviate build() {
}

environment.forEach((name, value) -> c.withEnv(name, value));
c.withCreateContainerCmdModifier(cmd -> cmd.withHostName("weaviate"));
c.withCreateContainerCmdModifier(cmd -> cmd.withHostName(containerName));
return c;
}
}

private Weaviate(String dockerImageName) {
private Weaviate() {
this("weaviate", DOCKER_IMAGE + ":" + VERSION);
}

private Weaviate(String containerName, String dockerImageName) {
super(dockerImageName);
this.containerName = containerName;
}

@Override
Expand Down Expand Up @@ -264,4 +297,92 @@ private void close(Weaviate caller) throws Exception {
public void close() throws IOException {
}
}

public static Weaviate cluster(int replicas) {
List<Weaviate> nodes = new ArrayList<>();
for (var i = 0; i < replicas; i++) {
nodes.add(Weaviate.custom()
.withContainerName("weaviate-" + i)
.build());
}
return new Cluster(nodes);
}

public static class Cluster extends Weaviate {
/** Leader and followers combined. */
private final List<Weaviate> nodes;

private final Weaviate leader;
private final List<Weaviate> followers;

private Cluster(List<Weaviate> nodes) {
assert nodes.size() > 1 : "cluster must have 1+ nodes";

this.nodes = List.copyOf(nodes);
this.leader = nodes.remove(0);
this.followers = List.copyOf(nodes);

for (var follower : followers) {
follower.dependsOn(leader);
}
setNetwork(Network.SHARED);
bindNodes(7110, 7111, 8300);
}

@Override
public WeaviateContainer dependsOn(List<? extends Startable> startables) {
leader.dependsOn(startables);
return this;
}

@Override
public void setNetwork(Network network) {
nodes.forEach(node -> node.setNetwork(network));
}

@Override
public WeaviateClient getClient() {
if (!isRunning()) {
start();
}
return leader.getClient();
}

@Override
public void start() {
followers.forEach(Startable::start); // testcontainers will resolve dependencies
}

@Override
public void stop() {
followers.forEach(Startable::stop);
leader.stop();
}

/**
* Set environment variables for inter-cluster communication.
*
* @param gossip Gossip bind port.
* @param data Data bind port.
* @param raft RAFT port.
*/
private void bindNodes(int gossip, int data, int raft) {
var publicPort = leader.getExposedPorts().get(0); // see WeaviateContainer Testcontainer.

nodes.forEach(node -> node
.withEnv("CLUSTER_GOSSIP_BIND_PORT", String.valueOf(gossip))
.withEnv("CLUSTER_DATA_BIND_PORT", String.valueOf(data))
.withEnv("RAFT_PORT", String.valueOf(raft))
.withEnv("RAFT_BOOTSTRAP_EXPECT", "1"));

followers.forEach(node -> node
.withEnv("CLUSTER_JOIN", leader.containerName + ":" + gossip)
.withEnv("RAFT_JOIN", leader.containerName)
.waitingFor(
Wait.forHttp("/v1/.well-known/ready")
.forPort(publicPort)
.forStatusCode(200)
.withStartupTimeout(Duration.ofSeconds(10))));
}
}
}
53 changes: 53 additions & 0 deletions src/it/java/io/weaviate/integration/ClusterITest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.weaviate.integration;

import java.io.IOException;

import org.assertj.core.api.Assertions;
import org.junit.Test;

import io.weaviate.ConcurrentTest;
import io.weaviate.client6.v1.api.WeaviateClient;
import io.weaviate.client6.v1.api.cluster.ShardingState;
import io.weaviate.containers.Weaviate;

public class ClusterITest extends ConcurrentTest {
private static final WeaviateClient client = Weaviate.cluster(3).getClient();

@Test
public void test_shardingState() throws IOException {
// Arrange
var nsA = ns("A");
var nsB = ns("B");

client.collections.create(nsA,
a -> a.replication(r -> r.replicationFactor(2)));
client.collections.create(nsB,
b -> b.replication(r -> r.replicationFactor(3)));

// Act
var optShardsA = client.cluster.shardingState(nsA);
var optShardsB = client.cluster.shardingState(nsB);

// Assert
var shardsA = Assertions.assertThat(optShardsA).get()
.returns(nsA, ShardingState::collection)
.extracting(ShardingState::shards)
.actual();

var shardsB = Assertions.assertThat(optShardsB).get()
.returns(nsB, ShardingState::collection)
.extracting(ShardingState::shards)
.actual();

Assertions.assertThat(shardsA).doesNotContainAnyElementsOf(shardsB);
}

@Test
public void test_listNodes() throws IOException {
// Act
var allNodes = client.cluster.listNodes();

// Assert
Assertions.assertThat(allNodes).as("total no. nodes").hasSize(3);
}
}
10 changes: 9 additions & 1 deletion src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import io.weaviate.client6.v1.api.alias.WeaviateAliasClient;
import io.weaviate.client6.v1.api.backup.WeaviateBackupClient;
import io.weaviate.client6.v1.api.cluster.WeaviateClusterClient;
import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClient;
import io.weaviate.client6.v1.api.rbac.groups.WeaviateGroupsClient;
import io.weaviate.client6.v1.api.rbac.roles.WeaviateRolesClient;
Expand Down Expand Up @@ -37,7 +38,7 @@ public class WeaviateClient implements AutoCloseable {

/** Client for {@code /backups} endpoints for managing backups. */
public final WeaviateBackupClient backup;

/**
* Client for {@code /authz/roles} endpoints for managing RBAC roles.
*/
Expand All @@ -53,6 +54,12 @@ public class WeaviateClient implements AutoCloseable {
*/
public final WeaviateUsersClient users;

/**
* Client for {@code /nodes} and {@code /replication} endpoints
* for managing replication and sharding.
*/
public final WeaviateClusterClient cluster;

public WeaviateClient(Config config) {
RestTransportOptions restOpt;
GrpcChannelOptions grpcOpt;
Expand Down Expand Up @@ -108,6 +115,7 @@ public WeaviateClient(Config config) {
this.roles = new WeaviateRolesClient(restTransport);
this.groups = new WeaviateGroupsClient(restTransport);
this.users = new WeaviateUsersClient(restTransport);
this.cluster = new WeaviateClusterClient(restTransport);
this.config = config;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import io.weaviate.client6.v1.api.alias.WeaviateAliasClientAsync;
import io.weaviate.client6.v1.api.backup.WeaviateBackupClientAsync;
import io.weaviate.client6.v1.api.cluster.WeaviateClusterClientAsync;
import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClient;
import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClientAsync;
import io.weaviate.client6.v1.api.rbac.groups.WeaviateGroupsClientAsync;
Expand Down Expand Up @@ -52,6 +53,12 @@ public class WeaviateClientAsync implements AutoCloseable {
*/
public final WeaviateUsersClientAsync users;

/**
* Client for {@code /nodes} and {@code /replication} endpoints
* for managing replication and sharding.
*/
public final WeaviateClusterClientAsync cluster;

/**
* This constructor is blocking if {@link Authentication} configured,
* as the client will need to do the initial token exchange.
Expand Down Expand Up @@ -110,6 +117,7 @@ public WeaviateClientAsync(Config config) {
this.roles = new WeaviateRolesClientAsync(restTransport);
this.groups = new WeaviateGroupsClientAsync(restTransport);
this.users = new WeaviateUsersClientAsync(restTransport);
this.cluster = new WeaviateClusterClientAsync(restTransport);
this.collections = new WeaviateCollectionsClientAsync(restTransport, grpcTransport);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package io.weaviate.client6.v1.api.cluster;

import com.google.gson.annotations.SerializedName;

public record AsyncReplicationStatus(
@SerializedName("objectsPropagated") long objectsPropagated,
@SerializedName("startDiffTimeUnixMillis") long startDiffTimeUnixMillis,
@SerializedName("targetNode") String targetNode) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package io.weaviate.client6.v1.api.cluster;

import com.google.gson.annotations.SerializedName;

public record CollectionStats(
@SerializedName("shardCount") int shardCount,
@SerializedName("objectCount") long objectCount) {
}
Loading