diff --git a/core/trino-main/src/main/java/io/trino/server/Server.java b/core/trino-main/src/main/java/io/trino/server/Server.java index 68ff9d8040f81..58deacb9b4533 100644 --- a/core/trino-main/src/main/java/io/trino/server/Server.java +++ b/core/trino-main/src/main/java/io/trino/server/Server.java @@ -77,6 +77,7 @@ import java.util.Optional; import java.util.Set; +import static com.google.common.collect.MoreCollectors.toOptional; import static io.airlift.discovery.client.ServiceAnnouncement.ServiceAnnouncementBuilder; import static io.airlift.discovery.client.ServiceAnnouncement.serviceAnnouncement; import static io.trino.server.TrinoSystemRequirements.verifyJvmRequirements; @@ -268,12 +269,10 @@ private static void updateConnectorIds(Announcer announcer, CatalogManager catal private static ServiceAnnouncement getTrinoAnnouncement(Set announcements) { - for (ServiceAnnouncement announcement : announcements) { - if (announcement.getType().equals("trino")) { - return announcement; - } - } - throw new IllegalArgumentException("Trino announcement not found: " + announcements); + return announcements.stream() + .filter(announcement -> announcement.getType().equals("trino")) + .collect(toOptional()) + .orElseThrow(() -> new IllegalArgumentException("Trino announcement not found: " + announcements)); } private static void logLocation(Logger log, String name, Path path) diff --git a/core/trino-main/src/main/java/io/trino/server/TaskResource.java b/core/trino-main/src/main/java/io/trino/server/TaskResource.java index c8e3d5660c290..e714769025f09 100644 --- a/core/trino-main/src/main/java/io/trino/server/TaskResource.java +++ b/core/trino-main/src/main/java/io/trino/server/TaskResource.java @@ -96,6 +96,7 @@ public class TaskResource private static final Duration ADDITIONAL_WAIT_TIME = new Duration(5, SECONDS); private static final Duration DEFAULT_MAX_WAIT_TIME = new Duration(2, SECONDS); + private final StartupStatus startupStatus; private final SqlTaskManager taskManager; private final SessionPropertyManager sessionPropertyManager; private final Executor responseExecutor; @@ -106,12 +107,14 @@ public class TaskResource @Inject public TaskResource( + StartupStatus startupStatus, SqlTaskManager taskManager, SessionPropertyManager sessionPropertyManager, @ForAsyncHttp BoundedExecutor responseExecutor, @ForAsyncHttp ScheduledExecutorService timeoutExecutor, FailureInjector failureInjector) { + this.startupStatus = requireNonNull(startupStatus, "startupStatus is null"); this.taskManager = requireNonNull(taskManager, "taskManager is null"); this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); this.responseExecutor = requireNonNull(responseExecutor, "responseExecutor is null"); @@ -143,6 +146,9 @@ public void createOrUpdateTask( @Suspended AsyncResponse asyncResponse) { requireNonNull(taskUpdateRequest, "taskUpdateRequest is null"); + if (failRequestIfInvalid(asyncResponse)) { + return; + } Session session = taskUpdateRequest.session().toSession(sessionPropertyManager, taskUpdateRequest.extraCredentials(), taskUpdateRequest.exchangeEncryptionKey()); @@ -179,6 +185,9 @@ public void getTaskInfo( @Suspended AsyncResponse asyncResponse) { requireNonNull(taskId, "taskId is null"); + if (failRequestIfInvalid(asyncResponse)) { + return; + } if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.GET_TASK_INFO, asyncResponse)) { return; @@ -224,6 +233,9 @@ public void getTaskStatus( @Suspended AsyncResponse asyncResponse) { requireNonNull(taskId, "taskId is null"); + if (failRequestIfInvalid(asyncResponse)) { + return; + } if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.GET_TASK_STATUS, asyncResponse)) { return; @@ -265,6 +277,9 @@ public void acknowledgeAndGetNewDynamicFilterDomains( { requireNonNull(taskId, "taskId is null"); requireNonNull(currentDynamicFiltersVersion, "currentDynamicFiltersVersion is null"); + if (failRequestIfInvalid(asyncResponse)) { + return; + } if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.ACKNOWLEDGE_AND_GET_NEW_DYNAMIC_FILTER_DOMAINS, asyncResponse)) { return; @@ -397,6 +412,22 @@ public void pruneCatalogs(Set catalogHandles) taskManager.pruneCatalogs(catalogHandles); } + private boolean failRequestIfInvalid(AsyncResponse asyncResponse) + { + if (!startupStatus.isStartupComplete()) { + // When worker node is restarted after a crash, coordinator may be still unaware of the situation and may attempt to schedule tasks on it. + // Ideally the coordinator should not schedule tasks on worker that is not ready, but in pipelined execution there is currently no way to move a task. + // Accepting a request too early will likely lead to some failure and HTTP 500 (INTERNAL_SERVER_ERROR) response. The coordinator won't retry on this. + // Send 503 (SERVICE_UNAVAILABLE) so that request is retried. + asyncResponse.resume(Response.status(Status.SERVICE_UNAVAILABLE) + .type(MediaType.TEXT_PLAIN_TYPE) + .entity("The server is not fully started yet ") + .build()); + return true; + } + return false; + } + private boolean injectFailure( Optional traceToken, TaskId taskId, diff --git a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java index afcf089940438..a62bcd5130216 100644 --- a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java +++ b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java @@ -83,6 +83,7 @@ import io.trino.server.SessionPropertyDefaults; import io.trino.server.SessionSupplier; import io.trino.server.ShutdownAction; +import io.trino.server.StartupStatus; import io.trino.server.security.CertificateAuthenticatorManager; import io.trino.server.security.ServerSecurityModule; import io.trino.spi.ErrorType; @@ -424,6 +425,7 @@ private TestingTrinoServer( eventListeners.forEach(eventListenerManager::addEventListener); getFutureValue(injector.getInstance(Announcer.class).forceAnnounce()); + injector.getInstance(StartupStatus.class).startupComplete(); refreshNodes(); } diff --git a/testing/trino-testing/pom.xml b/testing/trino-testing/pom.xml index 87a34ea616c0d..fcc74b3b8cbce 100644 --- a/testing/trino-testing/pom.xml +++ b/testing/trino-testing/pom.xml @@ -57,6 +57,11 @@ configuration + + io.airlift + http-server + + io.airlift log @@ -200,6 +205,12 @@ assertj-core + + org.eclipse.jetty + jetty-server + 12.0.9 + + org.jdbi jdbi3-core @@ -246,4 +257,4 @@ test - \ No newline at end of file + diff --git a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java index 384ce744baab6..1b76edbe6e69d 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java @@ -20,6 +20,7 @@ import com.google.inject.Key; import com.google.inject.Module; import io.airlift.discovery.server.testing.TestingDiscoveryServer; +import io.airlift.http.server.HttpServer; import io.airlift.log.Logger; import io.airlift.log.Logging; import io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter; @@ -57,10 +58,13 @@ import io.trino.sql.planner.Plan; import io.trino.testing.containers.OpenTracingCollector; import io.trino.transaction.TransactionManager; +import org.eclipse.jetty.server.Connector; +import org.eclipse.jetty.server.Server; import org.intellij.lang.annotations.Language; import java.io.IOException; import java.io.UncheckedIOException; +import java.lang.reflect.Field; import java.net.URI; import java.nio.file.Path; import java.util.HashMap; @@ -79,6 +83,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Throwables.throwIfUnchecked; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.inject.util.Modules.EMPTY_MODULE; import static io.airlift.log.Level.DEBUG; import static io.airlift.log.Level.ERROR; @@ -88,7 +93,9 @@ import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector; import static java.lang.Boolean.parseBoolean; import static java.lang.System.getenv; +import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; public class DistributedQueryRunner implements QueryRunner @@ -100,7 +107,7 @@ public class DistributedQueryRunner private TestingDiscoveryServer discoveryServer; private TestingTrinoServer coordinator; private Optional backupCoordinator; - private Runnable registerNewWorker; + private Consumer> registerNewWorker; private final InMemorySpanExporter spanExporter = InMemorySpanExporter.create(); private final List servers = new CopyOnWriteArrayList<>(); private final List functionBundles = new CopyOnWriteArrayList<>(ImmutableList.of(CustomFunctionBundle.CUSTOM_FUNCTIONS)); @@ -149,11 +156,14 @@ private DistributedQueryRunner( extraCloseables.forEach(closeable -> closer.register(() -> closeUnchecked(closeable))); log.debug("Created TestingDiscoveryServer in %s", nanosSince(discoveryStart)); - registerNewWorker = () -> { + registerNewWorker = additionalWorkerProperties -> { @SuppressWarnings("resource") TestingTrinoServer ignored = createServer( false, - extraProperties, + ImmutableMap.builder() + .putAll(extraProperties) + .putAll(additionalWorkerProperties) + .buildOrThrow(), environment, additionalModule, baseDataDir, @@ -163,7 +173,7 @@ private DistributedQueryRunner( }; for (int i = 0; i < workerCount; i++) { - registerNewWorker.run(); + registerNewWorker.accept(Map.of()); } Map extraCoordinatorProperties = new HashMap<>(); @@ -317,11 +327,37 @@ private static TestingTrinoServer createTestingTrinoServer( public void addServers(int nodeCount) { for (int i = 0; i < nodeCount; i++) { - registerNewWorker.run(); + registerNewWorker.accept(Map.of()); } ensureNodesGloballyVisible(); } + /** + * Simulate worker restart as e.g. in Kubernetes after a pod is killed. + */ + public void restartWorker(TestingTrinoServer server) + throws Exception + { + URI baseUrl = server.getBaseUrl(); + checkState(servers.remove(server), "Server not found: %s", server); + HttpServer workerHttpServer = server.getInstance(Key.get(HttpServer.class)); + // Prevent any HTTP communication with the worker, as if the worker process was killed. + Field serverField = HttpServer.class.getDeclaredField("server"); + serverField.setAccessible(true); + Connector httpConnector = getOnlyElement(asList(((Server) serverField.get(workerHttpServer)).getConnectors())); + httpConnector.stop(); + server.close(); + + Map reusePort = Map.of("http-server.http.port", Integer.toString(baseUrl.getPort())); + registerNewWorker.accept(reusePort); + // Verify the address was reused. + assertThat(servers.stream() + .map(TestingTrinoServer::getBaseUrl) + .filter(baseUrl::equals)) + .hasSize(1); + // Do not wait for new server to be fully registered with other servers + } + private void ensureNodesGloballyVisible() { for (TestingTrinoServer server : servers) { @@ -589,7 +625,7 @@ public final void close() discoveryServer = null; coordinator = null; backupCoordinator = Optional.empty(); - registerNewWorker = () -> { + registerNewWorker = _ -> { throw new IllegalStateException("Already closed"); }; servers.clear(); diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestWorkerRestart.java b/testing/trino-tests/src/test/java/io/trino/tests/TestWorkerRestart.java new file mode 100644 index 0000000000000..ffdbccad86c13 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestWorkerRestart.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.tests; + +import io.trino.execution.QueryManager; +import io.trino.server.BasicQueryInfo; +import io.trino.server.testing.TestingTrinoServer; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.MaterializedResult; +import io.trino.tests.tpch.TpchQueryRunner; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.parallel.Execution; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import static com.google.common.collect.MoreCollectors.onlyElement; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.trino.execution.QueryState.RUNNING; +import static io.trino.testing.assertions.Assert.assertEventually; +import static java.util.UUID.randomUUID; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.parallel.ExecutionMode.SAME_THREAD; + +/** + * Test that tasks are cleanly rejected when a node restarts (same node ID, different node instance ID). + */ +@Execution(SAME_THREAD) // run single threaded to avoid creating multiple query runners at once +public class TestWorkerRestart +{ + // When working with the test locally it's practical to run multiple iterations at once. + private static final int TEST_ITERATIONS = 1; + + /** + * Test that query passes even if worker is restarted just before query. + */ + @RepeatedTest(TEST_ITERATIONS) + @Timeout(90) + public void testRestartBeforeQuery() + throws Exception + { + try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build(); + ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%d"))) { + try { + // Ensure everything initialized + assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue()) + .isEqualTo(60_175); + + restartWorker(queryRunner); + // Even though the worker is restarted before we send a query, it is not fully announced to the coordinator. + // Coordinator will still try to send the query to the worker thinking it is the previous instance of it. + Future future = executor.submit(() -> queryRunner.execute("SELECT count(*) FROM tpch.sf1.lineitem -- " + randomUUID())); + future.get(); // query should succeed + + // Ensure that the restarted worker is able to serve queries. + assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue()) + .isEqualTo(60_175); + } + finally { + cancelQueries(queryRunner); + } + } + } + + /** + * Test that query fails with when worker crashes during its execution, but next query (e.g. retried query) succeeds without issues. + */ + @RepeatedTest(TEST_ITERATIONS) + @Timeout(90) + public void testRestartDuringQuery() + throws Exception + { + try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build(); + ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%d"))) { + try { + // Ensure everything initialized + assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue()) + .isEqualTo(60_175); + + String sql = "SELECT count(*) FROM tpch.sf1000000000.lineitem -- " + randomUUID(); + Future future = executor.submit(() -> queryRunner.execute(sql)); + waitForQueryStart(queryRunner, sql); + restartWorker(queryRunner); + assertThatThrownBy(future::get) + .isInstanceOf(ExecutionException.class) + .cause().hasMessageFindingMatch("^Expected response code from \\S+ to be 200, but was 500" + + "|Error fetching \\S+: Expected response code to be 200, but was 500"); + + // Ensure that the restarted worker is able to serve queries. + assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue()) + .isEqualTo(60_175); + } + finally { + cancelQueries(queryRunner); + } + } + } + + /** + * Test that query passes if a worker crashed before query started but it still potentially starting up when query is being scheduled. + */ + @RepeatedTest(TEST_ITERATIONS) + @Timeout(90) + public void testStartDuringQuery() + throws Exception + { + try (DistributedQueryRunner queryRunner = TpchQueryRunner.builder().build(); + ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getSimpleName() + "-%d"))) { + try { + // Ensure everything initialized + assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue()) + .isEqualTo(60_175); + + TestingTrinoServer worker = queryRunner.getServers().stream() + .filter(server -> !server.isCoordinator()) + .findFirst().orElseThrow(); + worker.close(); + Future future = executor.submit(() -> queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem -- " + randomUUID())); + // the worker is shut down already, but restartWorker() will reuse its address + queryRunner.restartWorker(worker); + future.get(); // query should succeed + + // Ensure that the restarted worker is able to serve queries. + assertThat((long) queryRunner.execute("SELECT count(*) FROM tpch.tiny.lineitem").getOnlyValue()) + .isEqualTo(60_175); + } + finally { + cancelQueries(queryRunner); + } + } + } + + private static void waitForQueryStart(DistributedQueryRunner queryRunner, String sql) + { + assertEventually(() -> { + BasicQueryInfo queryInfo = queryRunner.getCoordinator().getQueryManager().getQueries().stream() + .filter(query -> query.getQuery().equals(sql)) + .collect(onlyElement()); + assertThat(queryInfo.getState()).isEqualTo(RUNNING); + }); + } + + private static void restartWorker(DistributedQueryRunner queryRunner) + throws Exception + { + TestingTrinoServer worker = queryRunner.getServers().stream() + .filter(server -> !server.isCoordinator()) + .findFirst().orElseThrow(); + queryRunner.restartWorker(worker); + } + + private static void cancelQueries(DistributedQueryRunner queryRunner) + { + QueryManager queryManager = queryRunner.getCoordinator().getQueryManager(); + queryManager.getQueries().stream() + .map(BasicQueryInfo::getQueryId) + .forEach(queryManager::cancelQuery); + } +}