Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reject task on worker still starting up #21921

Merged
merged 4 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 5 additions & 6 deletions core/trino-main/src/main/java/io/trino/server/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import java.util.Optional;
import java.util.Set;

import static com.google.common.collect.MoreCollectors.toOptional;
import static io.airlift.discovery.client.ServiceAnnouncement.ServiceAnnouncementBuilder;
import static io.airlift.discovery.client.ServiceAnnouncement.serviceAnnouncement;
import static io.trino.server.TrinoSystemRequirements.verifyJvmRequirements;
Expand Down Expand Up @@ -268,12 +269,10 @@ private static void updateConnectorIds(Announcer announcer, CatalogManager catal

private static ServiceAnnouncement getTrinoAnnouncement(Set<ServiceAnnouncement> announcements)
{
for (ServiceAnnouncement announcement : announcements) {
if (announcement.getType().equals("trino")) {
return announcement;
}
}
throw new IllegalArgumentException("Trino announcement not found: " + announcements);
return announcements.stream()
.filter(announcement -> announcement.getType().equals("trino"))
.collect(toOptional())
.orElseThrow(() -> new IllegalArgumentException("Trino announcement not found: " + announcements));
}

private static void logLocation(Logger log, String name, Path path)
Expand Down
31 changes: 31 additions & 0 deletions core/trino-main/src/main/java/io/trino/server/TaskResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public class TaskResource
private static final Duration ADDITIONAL_WAIT_TIME = new Duration(5, SECONDS);
private static final Duration DEFAULT_MAX_WAIT_TIME = new Duration(2, SECONDS);

private final StartupStatus startupStatus;
private final SqlTaskManager taskManager;
private final SessionPropertyManager sessionPropertyManager;
private final Executor responseExecutor;
Expand All @@ -106,12 +107,14 @@ public class TaskResource

@Inject
public TaskResource(
StartupStatus startupStatus,
SqlTaskManager taskManager,
SessionPropertyManager sessionPropertyManager,
@ForAsyncHttp BoundedExecutor responseExecutor,
@ForAsyncHttp ScheduledExecutorService timeoutExecutor,
FailureInjector failureInjector)
{
this.startupStatus = requireNonNull(startupStatus, "startupStatus is null");
this.taskManager = requireNonNull(taskManager, "taskManager is null");
this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
this.responseExecutor = requireNonNull(responseExecutor, "responseExecutor is null");
Expand Down Expand Up @@ -143,6 +146,9 @@ public void createOrUpdateTask(
@Suspended AsyncResponse asyncResponse)
{
requireNonNull(taskUpdateRequest, "taskUpdateRequest is null");
if (failRequestIfInvalid(asyncResponse)) {
return;
}

Session session = taskUpdateRequest.session().toSession(sessionPropertyManager, taskUpdateRequest.extraCredentials(), taskUpdateRequest.exchangeEncryptionKey());

Expand Down Expand Up @@ -179,6 +185,9 @@ public void getTaskInfo(
@Suspended AsyncResponse asyncResponse)
{
requireNonNull(taskId, "taskId is null");
if (failRequestIfInvalid(asyncResponse)) {
return;
}

if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.GET_TASK_INFO, asyncResponse)) {
return;
Expand Down Expand Up @@ -224,6 +233,9 @@ public void getTaskStatus(
@Suspended AsyncResponse asyncResponse)
{
requireNonNull(taskId, "taskId is null");
if (failRequestIfInvalid(asyncResponse)) {
return;
}

if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.GET_TASK_STATUS, asyncResponse)) {
return;
Expand Down Expand Up @@ -265,6 +277,9 @@ public void acknowledgeAndGetNewDynamicFilterDomains(
{
requireNonNull(taskId, "taskId is null");
requireNonNull(currentDynamicFiltersVersion, "currentDynamicFiltersVersion is null");
if (failRequestIfInvalid(asyncResponse)) {
return;
}

if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.ACKNOWLEDGE_AND_GET_NEW_DYNAMIC_FILTER_DOMAINS, asyncResponse)) {
return;
Expand Down Expand Up @@ -397,6 +412,22 @@ public void pruneCatalogs(Set<CatalogHandle> catalogHandles)
taskManager.pruneCatalogs(catalogHandles);
}

private boolean failRequestIfInvalid(AsyncResponse asyncResponse)
{
if (!startupStatus.isStartupComplete()) {
// When worker node is restarted after a crash, coordinator may be still unaware of the situation and may attempt to schedule tasks on it.
// Ideally the coordinator should not schedule tasks on worker that is not ready, but in pipelined execution there is currently no way to move a task.
// Accepting a request too early will likely lead to some failure and HTTP 500 (INTERNAL_SERVER_ERROR) response. The coordinator won't retry on this.
// Send 503 (SERVICE_UNAVAILABLE) so that request is retried.
asyncResponse.resume(Response.status(Status.SERVICE_UNAVAILABLE)
.type(MediaType.TEXT_PLAIN_TYPE)
.entity("The server is not fully started yet ")
.build());
return true;
}
return false;
}

private boolean injectFailure(
Optional<String> traceToken,
TaskId taskId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
import io.trino.server.SessionPropertyDefaults;
import io.trino.server.SessionSupplier;
import io.trino.server.ShutdownAction;
import io.trino.server.StartupStatus;
import io.trino.server.security.CertificateAuthenticatorManager;
import io.trino.server.security.ServerSecurityModule;
import io.trino.spi.ErrorType;
Expand Down Expand Up @@ -424,6 +425,7 @@ private TestingTrinoServer(
eventListeners.forEach(eventListenerManager::addEventListener);

getFutureValue(injector.getInstance(Announcer.class).forceAnnounce());
injector.getInstance(StartupStatus.class).startupComplete();

refreshNodes();
}
Expand Down
13 changes: 12 additions & 1 deletion testing/trino-testing/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
<artifactId>configuration</artifactId>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>http-server</artifactId>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>log</artifactId>
Expand Down Expand Up @@ -200,6 +205,12 @@
<artifactId>assertj-core</artifactId>
</dependency>

<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
<version>12.0.9</version>
</dependency>

<dependency>
<groupId>org.jdbi</groupId>
<artifactId>jdbi3-core</artifactId>
Expand Down Expand Up @@ -246,4 +257,4 @@
<scope>test</scope>
</dependency>
</dependencies>
</project>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.inject.Key;
import com.google.inject.Module;
import io.airlift.discovery.server.testing.TestingDiscoveryServer;
import io.airlift.http.server.HttpServer;
import io.airlift.log.Logger;
import io.airlift.log.Logging;
import io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter;
Expand Down Expand Up @@ -57,10 +58,13 @@
import io.trino.sql.planner.Plan;
import io.trino.testing.containers.OpenTracingCollector;
import io.trino.transaction.TransactionManager;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.Server;
import org.intellij.lang.annotations.Language;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Field;
import java.net.URI;
import java.nio.file.Path;
import java.util.HashMap;
Expand All @@ -79,6 +83,7 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.inject.util.Modules.EMPTY_MODULE;
import static io.airlift.log.Level.DEBUG;
import static io.airlift.log.Level.ERROR;
Expand All @@ -88,7 +93,9 @@
import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector;
import static java.lang.Boolean.parseBoolean;
import static java.lang.System.getenv;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;

public class DistributedQueryRunner
implements QueryRunner
Expand All @@ -100,7 +107,7 @@ public class DistributedQueryRunner
private TestingDiscoveryServer discoveryServer;
private TestingTrinoServer coordinator;
private Optional<TestingTrinoServer> backupCoordinator;
private Runnable registerNewWorker;
private Consumer<Map<String, String>> registerNewWorker;
private final InMemorySpanExporter spanExporter = InMemorySpanExporter.create();
private final List<TestingTrinoServer> servers = new CopyOnWriteArrayList<>();
private final List<FunctionBundle> functionBundles = new CopyOnWriteArrayList<>(ImmutableList.of(CustomFunctionBundle.CUSTOM_FUNCTIONS));
Expand Down Expand Up @@ -149,11 +156,14 @@ private DistributedQueryRunner(
extraCloseables.forEach(closeable -> closer.register(() -> closeUnchecked(closeable)));
log.debug("Created TestingDiscoveryServer in %s", nanosSince(discoveryStart));

registerNewWorker = () -> {
registerNewWorker = additionalWorkerProperties -> {
@SuppressWarnings("resource")
TestingTrinoServer ignored = createServer(
false,
extraProperties,
ImmutableMap.<String, String>builder()
.putAll(extraProperties)
.putAll(additionalWorkerProperties)
.buildOrThrow(),
environment,
additionalModule,
baseDataDir,
Expand All @@ -163,7 +173,7 @@ private DistributedQueryRunner(
};

for (int i = 0; i < workerCount; i++) {
registerNewWorker.run();
registerNewWorker.accept(Map.of());
}

Map<String, String> extraCoordinatorProperties = new HashMap<>();
Expand Down Expand Up @@ -317,11 +327,37 @@ private static TestingTrinoServer createTestingTrinoServer(
public void addServers(int nodeCount)
{
for (int i = 0; i < nodeCount; i++) {
registerNewWorker.run();
registerNewWorker.accept(Map.of());
}
ensureNodesGloballyVisible();
}

/**
* Simulate worker restart as e.g. in Kubernetes after a pod is killed.
*/
public void restartWorker(TestingTrinoServer server)
throws Exception
{
URI baseUrl = server.getBaseUrl();
checkState(servers.remove(server), "Server not found: %s", server);
HttpServer workerHttpServer = server.getInstance(Key.get(HttpServer.class));
// Prevent any HTTP communication with the worker, as if the worker process was killed.
Field serverField = HttpServer.class.getDeclaredField("server");
serverField.setAccessible(true);
Connector httpConnector = getOnlyElement(asList(((Server) serverField.get(workerHttpServer)).getConnectors()));
httpConnector.stop();
server.close();

Map<String, String> reusePort = Map.of("http-server.http.port", Integer.toString(baseUrl.getPort()));
registerNewWorker.accept(reusePort);
// Verify the address was reused.
assertThat(servers.stream()
.map(TestingTrinoServer::getBaseUrl)
.filter(baseUrl::equals))
.hasSize(1);
// Do not wait for new server to be fully registered with other servers
}

private void ensureNodesGloballyVisible()
{
for (TestingTrinoServer server : servers) {
Expand Down Expand Up @@ -589,7 +625,7 @@ public final void close()
discoveryServer = null;
coordinator = null;
backupCoordinator = Optional.empty();
registerNewWorker = () -> {
registerNewWorker = _ -> {
throw new IllegalStateException("Already closed");
};
servers.clear();
Expand Down