diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardEjector.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardEjector.java index 33a54845cb38..bd3f9793ea17 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardEjector.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/storage/ShardEjector.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.raptor.storage; +import com.facebook.presto.raptor.NodeSupplier; import com.facebook.presto.raptor.RaptorConnectorId; import com.facebook.presto.raptor.backup.BackupStore; import com.facebook.presto.raptor.metadata.ShardManager; @@ -45,7 +46,6 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; -import static com.facebook.presto.spi.NodeState.ACTIVE; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Maps.filterKeys; @@ -64,7 +64,8 @@ public class ShardEjector { private static final Logger log = Logger.get(ShardEjector.class); - private final NodeManager nodeManager; + private final String currentNode; + private final NodeSupplier nodeSupplier; private final ShardManager shardManager; private final StorageService storageService; private final Duration interval; @@ -79,13 +80,15 @@ public class ShardEjector @Inject public ShardEjector( NodeManager nodeManager, + NodeSupplier nodeSupplier, ShardManager shardManager, StorageService storageService, StorageManagerConfig config, Optional backupStore, RaptorConnectorId connectorId) { - this(nodeManager, + this(nodeManager.getCurrentNode().getNodeIdentifier(), + nodeSupplier, shardManager, storageService, config.getShardEjectorInterval(), @@ -94,14 +97,16 @@ public ShardEjector( } public ShardEjector( - NodeManager nodeManager, + String currentNode, + NodeSupplier nodeSupplier, ShardManager shardManager, StorageService storageService, Duration interval, Optional backupStore, String connectorId) { - this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.currentNode = requireNonNull(currentNode, "currentNode is null"); + this.nodeSupplier = requireNonNull(nodeSupplier, "nodeSupplier is null"); this.shardManager = requireNonNull(shardManager, "shardManager is null"); this.storageService = requireNonNull(storageService, "storageService is null"); this.interval = requireNonNull(interval, "interval is null"); @@ -167,7 +172,7 @@ void process() // get the size of assigned shards for each node Map nodes = shardManager.getNodeBytes(); - Set activeNodes = nodeManager.getNodes(ACTIVE).stream() + Set activeNodes = nodeSupplier.getWorkerNodes().stream() .map(Node::getNodeIdentifier) .collect(toSet()); @@ -179,7 +184,6 @@ void process() } // get current node size - String currentNode = nodeManager.getCurrentNode().getNodeIdentifier(); if (!nodes.containsKey(currentNode)) { return; }