diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java index 83bba6172..cf3b557a9 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java @@ -37,6 +37,7 @@ public class HaGatewayConfiguration private List extraWhitelistPaths = new ArrayList<>(); private OAuth2GatewayCookieConfiguration oauth2GatewayCookieConfiguration = new OAuth2GatewayCookieConfiguration(); private GatewayCookieConfiguration gatewayCookieConfiguration = new GatewayCookieConfiguration(); + private List customStatementPaths = new ArrayList<>(); // List of Modules with FQCN (Fully Qualified Class Name) private List modules; @@ -205,4 +206,14 @@ public void setManagedApps(List managedApps) { this.managedApps = managedApps; } + + public List getCustomStatementPaths() + { + return customStatementPaths; + } + + public void setCustomStatementPaths(List customStatementPaths) + { + this.customStatementPaths = customStatementPaths; + } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/DropWizardProxyHandlerStats.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/DropWizardProxyHandlerStats.java new file mode 100644 index 000000000..3182cd381 --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/DropWizardProxyHandlerStats.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.handler; + +import com.codahale.metrics.Meter; +import io.airlift.stats.CounterStat; +import io.dropwizard.core.setup.Environment; +import org.weakref.jmx.MBeanExporter; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; + +import java.lang.management.ManagementFactory; + +import static java.util.Objects.requireNonNull; + +public final class DropWizardProxyHandlerStats + implements ProxyHandlerStats +{ + // Airlift + private final CounterStat requestCount = new CounterStat(); + + // Dropwizard + private final Meter requestMeter; + + private DropWizardProxyHandlerStats(Environment environment, String metricsName) + { + this.requestMeter = requireNonNull(environment, "environment is null") + .metrics() + .meter(requireNonNull(metricsName, "metricsName is null")); + } + + @Override + public void recordRequest() + { + requestCount.update(1); + requestMeter.mark(); + } + + // Replace this with Guice bind after migrated to Airlift + public static ProxyHandlerStats create(Environment environment, String metricsName) + { + ProxyHandlerStats proxyHandlerStats = new DropWizardProxyHandlerStats(environment, metricsName); + MBeanExporter exporter = new MBeanExporter(ManagementFactory.getPlatformMBeanServer()); + exporter.exportWithGeneratedName(proxyHandlerStats); + return proxyHandlerStats; + } + + @Override + @Managed + @Nested + public CounterStat getRequestCount() + { + return requestCount; + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyHandlerStats.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyHandlerStats.java index e5f21c001..b9a4ecae1 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyHandlerStats.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyHandlerStats.java @@ -13,51 +13,15 @@ */ package io.trino.gateway.ha.handler; -import com.codahale.metrics.Meter; import io.airlift.stats.CounterStat; -import io.dropwizard.core.setup.Environment; -import org.weakref.jmx.MBeanExporter; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import java.lang.management.ManagementFactory; - -import static java.util.Objects.requireNonNull; - -public final class ProxyHandlerStats +public interface ProxyHandlerStats { - // Airlift - private final CounterStat requestCount = new CounterStat(); - - // Dropwizard - private final Meter requestMeter; - - private ProxyHandlerStats(Environment environment, String metricsName) - { - this.requestMeter = requireNonNull(environment, "environment is null") - .metrics() - .meter(requireNonNull(metricsName, "metricsName is null")); - } - - public void recordRequest() - { - requestCount.update(1); - requestMeter.mark(); - } - - // Replace this with Guice bind after migrated to Airlift - public static ProxyHandlerStats create(Environment environment, String metricsName) - { - ProxyHandlerStats proxyHandlerStats = new ProxyHandlerStats(environment, metricsName); - MBeanExporter exporter = new MBeanExporter(ManagementFactory.getPlatformMBeanServer()); - exporter.exportWithGeneratedName(proxyHandlerStats); - return proxyHandlerStats; - } + void recordRequest(); @Managed @Nested - public CounterStat getRequestCount() - { - return requestCount; - } + CounterStat getRequestCount(); } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java index 4a401a35e..b1cc4a772 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java @@ -44,6 +44,7 @@ import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Stream; import static com.google.common.base.Strings.isNullOrEmpty; @@ -90,6 +91,7 @@ public class QueryIdCachingProxyHandler private final ProxyHandlerStats proxyHandlerStats; private final List extraWhitelistPaths; + private final List customStatementPaths; private final String applicationEndpoint; private final boolean cookiesEnabled; @@ -99,18 +101,20 @@ public QueryIdCachingProxyHandler( RoutingGroupSelector routingGroupSelector, int serverApplicationPort, ProxyHandlerStats proxyHandlerStats, - List extraWhitelistPaths) + List extraWhitelistPaths, + List customStatementPaths) { this.proxyHandlerStats = proxyHandlerStats; this.routingManager = routingManager; this.routingGroupSelector = routingGroupSelector; this.queryHistoryManager = queryHistoryManager; this.extraWhitelistPaths = extraWhitelistPaths; + this.customStatementPaths = customStatementPaths; this.applicationEndpoint = "http://localhost:" + serverApplicationPort; cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled(); } - protected static String extractQueryIdIfPresent(String path, String queryParams) + protected String extractQueryIdIfPresent(String path, String queryParams) { if (path == null) { return null; @@ -118,17 +122,20 @@ protected static String extractQueryIdIfPresent(String path, String queryParams) String queryId = null; log.debug("trying to extract query id from path [%s] or queryString [%s]", path, queryParams); - if (path.startsWith(V1_STATEMENT_PATH) || path.startsWith(V1_QUERY_PATH)) { - String[] tokens = path.split("/"); - if (tokens.length >= 4) { + Optional pathPrefix = Stream.concat(Stream.of(V1_STATEMENT_PATH, V1_QUERY_PATH), customStatementPaths.stream()).filter(path::startsWith).findFirst(); + if (pathPrefix.isPresent()) { + String reducedPath = path.replace(pathPrefix.orElseThrow(), ""); + + String[] tokens = reducedPath.split("/"); + if (tokens.length >= 2) { if (path.contains("queued") || path.contains("scheduled") || path.contains("executing") || path.contains("partialCancel")) { - queryId = tokens[4]; + queryId = tokens[2]; } else { - queryId = tokens[3]; + queryId = tokens[1]; } } } @@ -246,7 +253,8 @@ protected String extractQueryIdIfPresent(HttpServletRequest request) public void preConnectionHook(HttpServletRequest request, Request proxyRequest) { if (request.getMethod().equals(HttpMethod.POST) - && request.getRequestURI().startsWith(V1_STATEMENT_PATH)) { + && (request.getRequestURI().startsWith(V1_STATEMENT_PATH) + || customStatementPaths.stream().anyMatch(request.getRequestURI()::startsWith))) { proxyHandlerStats.recordRequest(); try { String requestBody = CharStreams.toString(request.getReader()); @@ -275,7 +283,8 @@ private boolean isPathWhiteListed(String path) || path.startsWith(V1_NODE_PATH) || path.startsWith(UI_API_STATS_PATH) || path.startsWith(OAUTH_PATH) - || extraWhitelistPaths.stream().anyMatch(s -> path.startsWith(s)); + || extraWhitelistPaths.stream().anyMatch(path::startsWith) + || customStatementPaths.stream().anyMatch(path::startsWith); } @Override @@ -378,10 +387,11 @@ protected void postConnectionHook( Callback callback) { try { - if (request.getRequestURI().startsWith(V1_STATEMENT_PATH) && request.getMethod().equals(HttpMethod.POST)) { + String requestPath = request.getRequestURI(); + if ((requestPath.startsWith(V1_STATEMENT_PATH) || customStatementPaths.stream().anyMatch(requestPath::startsWith)) && request.getMethod().equals(HttpMethod.POST)) { recordBackendForQueryId(request, response, buffer); } - else if (cookiesEnabled && request.getRequestURI().startsWith(OAuth2GatewayCookie.OAUTH2_PATH) + else if (cookiesEnabled && requestPath.startsWith(OAuth2GatewayCookie.OAUTH2_PATH) && !(request.getCookies() != null && Arrays.stream(request.getCookies()).anyMatch(c -> c.getName().equals(OAuth2GatewayCookie.NAME)))) { GatewayCookie oauth2Cookie = new OAuth2GatewayCookie(request.getHeader(PROXY_TARGET_HEADER)); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java index 8ee294d5a..2b78f8f0e 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java @@ -33,6 +33,7 @@ import io.trino.gateway.ha.config.RequestRouterConfiguration; import io.trino.gateway.ha.config.RoutingRulesConfiguration; import io.trino.gateway.ha.config.UserConfiguration; +import io.trino.gateway.ha.handler.DropWizardProxyHandlerStats; import io.trino.gateway.ha.handler.ProxyHandlerStats; import io.trino.gateway.ha.handler.QueryIdCachingProxyHandler; import io.trino.gateway.ha.router.BackendStateManager; @@ -72,6 +73,7 @@ public class HaGatewayProviderModule private final BackendStateManager backendStateConnectionManager; private final AuthFilter authenticationFilter; private final List extraWhitelistPaths; + private final List customStatementPaths; private final HaGatewayConfiguration configuration; private final Environment environment; @@ -94,6 +96,7 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration, Environment OAuth2GatewayCookieConfigurationPropertiesProvider oAuth2GatewayCookieConfigurationPropertiesProvider = OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance(); oAuth2GatewayCookieConfigurationPropertiesProvider.initialize(configuration.getOauth2GatewayCookieConfiguration()); + customStatementPaths = configuration.getCustomStatementPaths(); } private LbOAuthManager getOAuthManager(HaGatewayConfiguration configuration) @@ -149,9 +152,9 @@ private ChainedAuthFilter getAuthenticationFilters(AuthenticationConfiguration c } private ProxyHandler getProxyHandler(QueryHistoryManager queryHistoryManager, - RoutingManager routingManager) + RoutingManager routingManager) { - ProxyHandlerStats proxyHandlerStats = ProxyHandlerStats.create( + ProxyHandlerStats proxyHandlerStats = DropWizardProxyHandlerStats.create( environment, configuration.getRequestRouter().getName() + ".requests"); @@ -170,7 +173,8 @@ private ProxyHandler getProxyHandler(QueryHistoryManager queryHistoryManager, routingGroupSelector, getApplicationPort(), proxyHandlerStats, - extraWhitelistPaths); + extraWhitelistPaths, + customStatementPaths); } private int getApplicationPort() @@ -211,7 +215,7 @@ private AuthFilter getAuthFilter(HaGatewayConfiguration configuration) @Provides @Singleton public ProxyServer provideGateway(QueryHistoryManager queryHistoryManager, - RoutingManager routingManager) + RoutingManager routingManager) { ProxyServer gateway = null; if (configuration.getRequestRouter() != null) { diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/CachingRoutingManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/CachingRoutingManager.java new file mode 100644 index 000000000..7049fef25 --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/CachingRoutingManager.java @@ -0,0 +1,205 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import io.airlift.log.Logger; +import io.trino.gateway.ha.clustermonitor.ClusterStats; +import io.trino.gateway.ha.config.ProxyBackendConfiguration; +import io.trino.gateway.proxyserver.ProxyServerConfiguration; +import jakarta.ws.rs.HttpMethod; + +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +/** + * This class performs health check, stats counts for each backend and provides a backend given + * request object. Default implementation comes here. + */ +public abstract class CachingRoutingManager + implements RoutingManager +{ + private static final Random RANDOM = new Random(); + private static final Logger log = Logger.get(CachingRoutingManager.class); + private final LoadingCache queryIdBackendCache; + private final ExecutorService executorService = Executors.newFixedThreadPool(5); + private final GatewayBackendManager gatewayBackendManager; + private final ConcurrentHashMap backendToHealth; + + public CachingRoutingManager(GatewayBackendManager gatewayBackendManager) + { + this.gatewayBackendManager = gatewayBackendManager; + queryIdBackendCache = + CacheBuilder.newBuilder() + .maximumSize(10000) + .expireAfterAccess(30, TimeUnit.MINUTES) + .build( + new CacheLoader() + { + @Override + public String load(String queryId) + { + return findBackendForUnknownQueryId(queryId); + } + }); + + this.backendToHealth = new ConcurrentHashMap(); + } + + protected GatewayBackendManager getGatewayBackendManager() + { + return gatewayBackendManager; + } + + @Override + public void setBackendForQueryId(String queryId, String backend) + { + queryIdBackendCache.put(queryId, backend); + } + + /** + * Performs routing to an adhoc backend. + */ + @Override + public String provideAdhocBackend(String user) + { + List backends = this.gatewayBackendManager.getActiveAdhocBackends(); + backends.removeIf(backend -> isBackendNotHealthy(backend.getName())); + if (backends.size() == 0) { + throw new IllegalStateException("Number of active backends found zero"); + } + int backendId = Math.abs(RANDOM.nextInt()) % backends.size(); + return backends.get(backendId).getProxyTo(); + } + + /** + * Performs routing to a given cluster group. This falls back to an adhoc backend, if no scheduled + * backend is found. + */ + @Override + public String provideBackendForRoutingGroup(String routingGroup, String user) + { + List backends = + gatewayBackendManager.getActiveBackends(routingGroup); + backends.removeIf(backend -> isBackendNotHealthy(backend.getName())); + if (backends.isEmpty()) { + return provideAdhocBackend(user); + } + int backendId = Math.abs(RANDOM.nextInt()) % backends.size(); + return backends.get(backendId).getProxyTo(); + } + + /** + * Performs cache look up, if a backend not found, it checks with all backends and tries to find + * out which backend has info about given query id. + */ + @Override + public String findBackendForQueryId(String queryId) + { + String backendAddress = null; + try { + backendAddress = queryIdBackendCache.get(queryId); + } + catch (ExecutionException e) { + log.error("Exception while loading queryId from cache %s", e.getLocalizedMessage()); + } + return backendAddress; + } + + @Override + public void updateBackEndHealth(String backendId, Boolean value) + { + log.info("backend %s isHealthy %s", backendId, value); + backendToHealth.put(backendId, value); + } + + @Override + public void updateBackEndStats(List stats) + { + for (ClusterStats clusterStats : stats) { + updateBackEndHealth(clusterStats.clusterId(), clusterStats.healthy()); + } + } + + /** + * This tries to find out which backend may have info about given query id. If not found returns + * the first healthy backend. + */ + @Override + public String findBackendForUnknownQueryId(String queryId) + { + List backends = gatewayBackendManager.getAllBackends(); + + Map> responseCodes = new HashMap<>(); + try { + for (ProxyServerConfiguration backend : backends) { + String target = backend.getProxyTo() + "/v1/query/" + queryId; + + Future call = + executorService.submit( + () -> { + URL url = new URL(target); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setConnectTimeout((int) TimeUnit.SECONDS.toMillis(5)); + conn.setReadTimeout((int) TimeUnit.SECONDS.toMillis(5)); + conn.setRequestMethod(HttpMethod.HEAD); + return conn.getResponseCode(); + }); + responseCodes.put(backend.getProxyTo(), call); + } + for (Map.Entry> entry : responseCodes.entrySet()) { + if (entry.getValue().isDone()) { + int responseCode = entry.getValue().get(); + if (responseCode == 200) { + log.info("Found query [%s] on backend [%s]", queryId, entry.getKey()); + setBackendForQueryId(queryId, entry.getKey()); + return entry.getKey(); + } + } + } + } + catch (Exception e) { + log.warn("Query id [%s] not found", queryId); + } + // Fallback on first active backend if queryId mapping not found. + return gatewayBackendManager.getActiveAdhocBackends().get(0).getProxyTo(); + } + + // Predicate helper function to remove the backends from the list + // We are returning the unhealthy (not healthy) + private boolean isBackendNotHealthy(String backendId) + { + if (backendToHealth.isEmpty()) { + CachingRoutingManager.log.error("backends can not be empty"); + return true; + } + Boolean isHealthy = backendToHealth.get(backendId); + if (isHealthy == null) { + return true; + } + return !isHealthy; + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/RoutingManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/RoutingManager.java index f40510f25..3b7d1e40a 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/RoutingManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/RoutingManager.java @@ -13,185 +13,23 @@ */ package io.trino.gateway.ha.router; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.CacheLoader; -import com.google.common.cache.LoadingCache; -import io.airlift.log.Logger; import io.trino.gateway.ha.clustermonitor.ClusterStats; -import io.trino.gateway.ha.config.ProxyBackendConfiguration; -import io.trino.gateway.proxyserver.ProxyServerConfiguration; -import jakarta.ws.rs.HttpMethod; -import java.net.HttpURLConnection; -import java.net.URL; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; -/** - * This class performs health check, stats counts for each backend and provides a backend given - * request object. Default implementation comes here. - */ -public abstract class RoutingManager +public interface RoutingManager { - private static final Random RANDOM = new Random(); - private static final Logger log = Logger.get(RoutingManager.class); - private final LoadingCache queryIdBackendCache; - private final ExecutorService executorService = Executors.newFixedThreadPool(5); - private final GatewayBackendManager gatewayBackendManager; - private final ConcurrentHashMap backendToHealth; - - public RoutingManager(GatewayBackendManager gatewayBackendManager) - { - this.gatewayBackendManager = gatewayBackendManager; - queryIdBackendCache = - CacheBuilder.newBuilder() - .maximumSize(10000) - .expireAfterAccess(30, TimeUnit.MINUTES) - .build( - new CacheLoader() - { - @Override - public String load(String queryId) - { - return findBackendForUnknownQueryId(queryId); - } - }); - - this.backendToHealth = new ConcurrentHashMap(); - } - - protected GatewayBackendManager getGatewayBackendManager() - { - return gatewayBackendManager; - } - - public void setBackendForQueryId(String queryId, String backend) - { - queryIdBackendCache.put(queryId, backend); - } - - /** - * Performs routing to an adhoc backend. - */ - public String provideAdhocBackend(String user) - { - List backends = this.gatewayBackendManager.getActiveAdhocBackends(); - backends.removeIf(backend -> isBackendNotHealthy(backend.getName())); - if (backends.size() == 0) { - throw new IllegalStateException("Number of active backends found zero"); - } - int backendId = Math.abs(RANDOM.nextInt()) % backends.size(); - return backends.get(backendId).getProxyTo(); - } - - /** - * Performs routing to a given cluster group. This falls back to an adhoc backend, if no scheduled - * backend is found. - */ - public String provideBackendForRoutingGroup(String routingGroup, String user) - { - List backends = - gatewayBackendManager.getActiveBackends(routingGroup); - backends.removeIf(backend -> isBackendNotHealthy(backend.getName())); - if (backends.isEmpty()) { - return provideAdhocBackend(user); - } - int backendId = Math.abs(RANDOM.nextInt()) % backends.size(); - return backends.get(backendId).getProxyTo(); - } - - /** - * Performs cache look up, if a backend not found, it checks with all backends and tries to find - * out which backend has info about given query id. - */ - public String findBackendForQueryId(String queryId) - { - String backendAddress = null; - try { - backendAddress = queryIdBackendCache.get(queryId); - } - catch (ExecutionException e) { - log.error("Exception while loading queryId from cache %s", e.getLocalizedMessage()); - } - return backendAddress; - } + void setBackendForQueryId(String queryId, String backend); - public void updateBackEndHealth(String backendId, Boolean value) - { - log.info("backend %s isHealthy %s", backendId, value); - backendToHealth.put(backendId, value); - } + String provideAdhocBackend(String user); - public void updateBackEndStats(List stats) - { - for (ClusterStats clusterStats : stats) { - updateBackEndHealth(clusterStats.clusterId(), clusterStats.healthy()); - } - } + String provideBackendForRoutingGroup(String routingGroup, String user); - /** - * This tries to find out which backend may have info about given query id. If not found returns - * the first healthy backend. - */ - protected String findBackendForUnknownQueryId(String queryId) - { - List backends = gatewayBackendManager.getAllBackends(); + String findBackendForQueryId(String queryId); - Map> responseCodes = new HashMap<>(); - try { - for (ProxyServerConfiguration backend : backends) { - String target = backend.getProxyTo() + "/v1/query/" + queryId; + void updateBackEndHealth(String backendId, Boolean value); - Future call = - executorService.submit( - () -> { - URL url = new URL(target); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); - conn.setConnectTimeout((int) TimeUnit.SECONDS.toMillis(5)); - conn.setReadTimeout((int) TimeUnit.SECONDS.toMillis(5)); - conn.setRequestMethod(HttpMethod.HEAD); - return conn.getResponseCode(); - }); - responseCodes.put(backend.getProxyTo(), call); - } - for (Map.Entry> entry : responseCodes.entrySet()) { - if (entry.getValue().isDone()) { - int responseCode = entry.getValue().get(); - if (responseCode == 200) { - log.info("Found query [%s] on backend [%s]", queryId, entry.getKey()); - setBackendForQueryId(queryId, entry.getKey()); - return entry.getKey(); - } - } - } - } - catch (Exception e) { - log.warn("Query id [%s] not found", queryId); - } - // Fallback on first active backend if queryId mapping not found. - return gatewayBackendManager.getActiveAdhocBackends().get(0).getProxyTo(); - } + void updateBackEndStats(List stats); - // Predicate helper function to remove the backends from the list - // We are returning the unhealthy (not healthy) - private boolean isBackendNotHealthy(String backendId) - { - if (backendToHealth.isEmpty()) { - log.error("backends can not be empty"); - return true; - } - Boolean isHealthy = backendToHealth.get(backendId); - if (isHealthy == null) { - return true; - } - return !isHealthy; - } + String findBackendForUnknownQueryId(String queryId); } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java index 269d3640c..39fb42fd1 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/StochasticRoutingManager.java @@ -17,7 +17,7 @@ import io.airlift.log.Logger; public class StochasticRoutingManager - extends RoutingManager + extends CachingRoutingManager { private static final Logger log = Logger.get(StochasticRoutingManager.class); QueryHistoryManager queryHistoryManager; @@ -30,7 +30,7 @@ public StochasticRoutingManager( } @Override - protected String findBackendForUnknownQueryId(String queryId) + public String findBackendForUnknownQueryId(String queryId) { String backend; backend = queryHistoryManager.getBackendForQueryId(queryId); diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java b/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java index 4e8cfd10f..bc6c5c00f 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java @@ -13,52 +13,98 @@ */ package io.trino.gateway.ha.handler; +import com.google.common.collect.ImmutableList; +import io.trino.gateway.ha.config.GatewayCookieConfiguration; +import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider; +import io.trino.gateway.ha.router.QueryHistoryManager; +import io.trino.gateway.ha.router.RoutingGroupSelector; +import io.trino.gateway.ha.router.RoutingManager; +import io.trino.gateway.ha.router.TestingProxyHandlerStats; +import io.trino.gateway.ha.router.TestingQueryManager; +import io.trino.gateway.ha.router.TestingRoutingManager; import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.ws.rs.HttpMethod; import org.eclipse.jetty.client.HttpClient; import org.eclipse.jetty.client.api.Request; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; import org.mockito.Mockito; +import java.io.BufferedReader; import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.util.List; -import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.extractQueryIdIfPresent; +import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.PROXY_TARGET_HEADER; import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; +import static org.eclipse.jetty.util.Callback.NOOP; +import static org.mockito.Mockito.when; @TestInstance(Lifecycle.PER_CLASS) public class TestQueryIdCachingProxyHandler { + QueryIdCachingProxyHandler queryIdCachingProxyHandler; + String customStatementPath = "/custom/statement/path"; + String backend = "backend"; + List customStatementPaths = ImmutableList.of(customStatementPath); + QueryHistoryManager queryHistoryManager = new TestingQueryManager(); + RoutingManager routingManager = new TestingRoutingManager(backend); + RoutingGroupSelector routingGroupSelector = RoutingGroupSelector.byRoutingGroupHeader(); + ProxyHandlerStats proxyHandlerStats = new TestingProxyHandlerStats(); + + @BeforeAll + public void setup() + { + GatewayCookieConfigurationPropertiesProvider.getInstance().initialize(new GatewayCookieConfiguration()); + queryIdCachingProxyHandler = new QueryIdCachingProxyHandler( + queryHistoryManager, + routingManager, + routingGroupSelector, + 80, + proxyHandlerStats, + ImmutableList.of(), + customStatementPaths); + } + @Test public void testExtractQueryIdFromUrl() throws IOException { - assertThat(extractQueryIdIfPresent("/v1/statement/executing/20200416_160256_03078_6b4yt/ya7e884929c67cdf86207a80e7a77ab2166fa2e7b/1368", null)) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/v1/statement/executing/20200416_160256_03078_6b4yt/ya7e884929c67cdf86207a80e7a77ab2166fa2e7b/1368", null)) .isEqualTo("20200416_160256_03078_6b4yt"); - assertThat(extractQueryIdIfPresent("/v1/statement/queued/20200416_160256_03078_6b4yt/y0d7620a6941e78d3950798a1085383234258a566/1", null)) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/v1/statement/queued/20200416_160256_03078_6b4yt/y0d7620a6941e78d3950798a1085383234258a566/1", null)) .isEqualTo("20200416_160256_03078_6b4yt"); - assertThat(extractQueryIdIfPresent("/ui/api/query/20200416_160256_03078_6b4yt", null)) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/ui/api/query/20200416_160256_03078_6b4yt", null)) .isEqualTo("20200416_160256_03078_6b4yt"); - assertThat(extractQueryIdIfPresent("/ui/api/query/20200416_160256_03078_6b4yt/killed", null)) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/ui/api/query/20200416_160256_03078_6b4yt/killed", null)) .isEqualTo("20200416_160256_03078_6b4yt"); - assertThat(extractQueryIdIfPresent("/ui/api/query/20200416_160256_03078_6b4yt/preempted", null)) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/ui/api/query/20200416_160256_03078_6b4yt/preempted", null)) .isEqualTo("20200416_160256_03078_6b4yt"); - assertThat(extractQueryIdIfPresent("/v1/query/20200416_160256_03078_6b4yt", "pretty")) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/v1/query/20200416_160256_03078_6b4yt", "pretty")) .isEqualTo("20200416_160256_03078_6b4yt"); - assertThat(extractQueryIdIfPresent("/ui/troubleshooting", "queryId=20200416_160256_03078_6b4yt")) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/ui/troubleshooting", "queryId=20200416_160256_03078_6b4yt")) .isEqualTo("20200416_160256_03078_6b4yt"); - assertThat(extractQueryIdIfPresent("/ui/query.html", "20200416_160256_03078_6b4yt")) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/ui/query.html", "20200416_160256_03078_6b4yt")) .isEqualTo("20200416_160256_03078_6b4yt"); - assertThat(extractQueryIdIfPresent("/login", "redirect=%2Fui%2Fapi%2Fquery%2F20200416_160256_03078_6b4yt")) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/login", "redirect=%2Fui%2Fapi%2Fquery%2F20200416_160256_03078_6b4yt")) .isEqualTo("20200416_160256_03078_6b4yt"); - assertThat(extractQueryIdIfPresent("/ui/api/query/myOtherThing", null)) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/ui/api/query/myOtherThing", null)) .isNull(); - assertThat(extractQueryIdIfPresent("/ui/api/query/20200416_blah", "bogus_fictional_param")) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/ui/api/query/20200416_blah", "bogus_fictional_param")) .isNull(); - assertThat(extractQueryIdIfPresent("/ui/", "lang=en&p=1&id=0_1_2_a")) + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent("/ui/", "lang=en&p=1&id=0_1_2_a")) .isNull(); + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent(customStatementPath + "/executing/20200416_160256_03078_6b4yt/ya7e884929c67cdf86207a80e7a77ab2166fa2e7b/1368", null)) + .isEqualTo("20200416_160256_03078_6b4yt"); + assertThat(queryIdCachingProxyHandler.extractQueryIdIfPresent(customStatementPath + "/queued/20200416_160256_03078_6b4yt/y0d7620a6941e78d3950798a1085383234258a566/1", null)) + .isEqualTo("20200416_160256_03078_6b4yt"); } @Test @@ -68,7 +114,7 @@ public void testForwardedHostHeaderOnProxyRequest() String backendServer = "trinocluster"; String backendPort = "80"; HttpServletRequest mockServletRequest = Mockito.mock(HttpServletRequest.class); - Mockito.when(mockServletRequest.getHeader("proxytarget")).thenReturn(format("http://%s:%s", backendServer, backendPort)); + when(mockServletRequest.getHeader("proxytarget")).thenReturn(format("http://%s:%s", backendServer, backendPort)); HttpClient httpClient = new HttpClient(); Request proxyRequest = httpClient.newRequest("http://localhost:80"); QueryIdCachingProxyHandler.setForwardedHostHeaderOnProxyRequest(mockServletRequest, @@ -77,6 +123,76 @@ public void testForwardedHostHeaderOnProxyRequest() .isEqualTo(format("%s:%s", backendServer, backendPort)); } + @Test + public void testPreconnectionHook() + { + String backendServer = "trinocluster"; + String backendPort = "80"; + + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + when(request.getHeader("proxytarget")).thenReturn(String.format("http://%s:%s", backendServer, backendPort)); + when(request.getRequestURI()).thenReturn(QueryIdCachingProxyHandler.V1_STATEMENT_PATH); + when(request.getMethod()).thenReturn(HttpMethod.POST); + HttpClient httpClient = new HttpClient(); + Request proxyRequest = httpClient.newRequest("http://localhost:80"); + queryIdCachingProxyHandler.preConnectionHook(request, proxyRequest); + assertThat(proxyRequest.getHeaders().get("Host")).isEqualTo(String.format("%s:%s", backendServer, backendPort)); + + when(request.getRequestURI()).thenReturn(customStatementPath); + proxyRequest = httpClient.newRequest("http://localhost:80"); + queryIdCachingProxyHandler.preConnectionHook(request, proxyRequest); + assertThat(proxyRequest.getHeaders().get("Host")).isEqualTo(String.format("%s:%s", backendServer, backendPort)); + + when(request.getRequestURI()).thenReturn("/v1/invalid/statement/path"); + proxyRequest = httpClient.newRequest("http://localhost:80"); + queryIdCachingProxyHandler.preConnectionHook(request, proxyRequest); + assertThat(proxyRequest.getHeaders().get("Host")).isNull(); + } + + @Test + public void testPostConnectionHook() + throws IOException + { + String backend = "trinocluster:80"; + String user = "usr"; + String source = "jdbc"; + String body = "Select 1"; + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + when(request.getMethod()).thenReturn(HttpMethod.POST); + when(request.getRequestURI()).thenReturn(QueryIdCachingProxyHandler.V1_STATEMENT_PATH); + when(request.getHeader(PROXY_TARGET_HEADER)).thenReturn(backend); + when(request.getHeader(QueryIdCachingProxyHandler.USER_HEADER)).thenReturn(user); + when(request.getHeader(QueryIdCachingProxyHandler.SOURCE_HEADER)).thenReturn(source); + Reader reader = new StringReader(body); + BufferedReader bufferedReader = new BufferedReader(reader); + when(request.getReader()).thenReturn(bufferedReader); + + HttpServletResponse response = Mockito.mock(HttpServletResponse.class); + when(response.getStatus()).thenReturn(200); + + String queryId = "v1_statement_id"; + String responseBody = String.format("{\"id\":\"%s\"}", queryId); + byte[] buffer = responseBody.getBytes(UTF_8); + + when(request.getRequestURI()).thenReturn(QueryIdCachingProxyHandler.V1_STATEMENT_PATH); + queryIdCachingProxyHandler.postConnectionHook(request, response, buffer, 0, buffer.length, NOOP); + assertThat(routingManager.findBackendForQueryId(queryId)).isEqualTo(backend); + + queryId = "custom_path_id"; + responseBody = String.format("{\"id\":\"%s\"}", queryId); + buffer = responseBody.getBytes(UTF_8); + when(request.getRequestURI()).thenReturn(customStatementPath); + queryIdCachingProxyHandler.postConnectionHook(request, response, buffer, 0, buffer.length, NOOP); + assertThat(routingManager.findBackendForQueryId(queryId)).isEqualTo(backend); + + queryId = "invalid_path_id"; + responseBody = String.format("{\"id\":\"%s\"}", queryId); + buffer = responseBody.getBytes(UTF_8); + when(request.getRequestURI()).thenReturn("/v1/invalid/statement/path"); + queryIdCachingProxyHandler.postConnectionHook(request, response, buffer, 0, buffer.length, NOOP); + assertThat(routingManager.findBackendForQueryId(queryId)).isNull(); + } + @Test public void testUserFromRequest() throws IOException @@ -84,12 +200,12 @@ public void testUserFromRequest() HttpServletRequest req = Mockito.mock(HttpServletRequest.class); String authHeader = "Basic dGVzdDoxMjPCow=="; - Mockito.when(req.getHeader(QueryIdCachingProxyHandler.AUTHORIZATION)) + when(req.getHeader(QueryIdCachingProxyHandler.AUTHORIZATION)) .thenReturn(authHeader); assertThat(QueryIdCachingProxyHandler.getQueryUser(req)).isEqualTo("test"); String user = "trino_user"; - Mockito.when(req.getHeader(QueryIdCachingProxyHandler.USER_HEADER)) + when(req.getHeader(QueryIdCachingProxyHandler.USER_HEADER)) .thenReturn(user); assertThat(QueryIdCachingProxyHandler.getQueryUser(req)).isEqualTo(user); } diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestingProxyHandlerStats.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestingProxyHandlerStats.java new file mode 100644 index 000000000..bd9609ce6 --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestingProxyHandlerStats.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import io.airlift.stats.CounterStat; +import io.trino.gateway.ha.handler.ProxyHandlerStats; + +public class TestingProxyHandlerStats + implements ProxyHandlerStats +{ + private final CounterStat requestCount = new CounterStat(); + + @Override + public void recordRequest() + { + requestCount.update(1); + } + + @Override + public CounterStat getRequestCount() + { + return requestCount; + } +} diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestingQueryManager.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestingQueryManager.java new file mode 100644 index 000000000..6278f98ff --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestingQueryManager.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import com.google.common.collect.ImmutableList; +import io.trino.gateway.ha.domain.TableData; +import io.trino.gateway.ha.domain.request.QueryHistoryRequest; +import io.trino.gateway.ha.domain.response.DistributionResponse; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class TestingQueryManager + implements QueryHistoryManager +{ + Map queryDetailMap = new HashMap<>(); + + public TestingQueryManager() + { + } + + @Override + public void submitQueryDetail(QueryDetail queryDetail) + { + queryDetailMap.put(queryDetail.getQueryId(), queryDetail); + } + + @Override + public List fetchQueryHistory(Optional user) + { + return queryDetailMap.values().stream().toList(); + } + + @Override + public String getBackendForQueryId(String queryId) + { + return queryDetailMap.get(queryId).getBackendUrl(); + } + + @Override + public TableData findQueryHistory(QueryHistoryRequest query) + { + return TableData.build(ImmutableList.of(queryDetailMap.get(query.queryId())), 1); + } + + @Override + public List findDistribution(Long ts) + { + throw new UnsupportedOperationException("Not implemented"); + } +} diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestingRoutingManager.java b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestingRoutingManager.java new file mode 100644 index 000000000..aa514fe77 --- /dev/null +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/router/TestingRoutingManager.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import io.trino.gateway.ha.clustermonitor.ClusterStats; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class TestingRoutingManager + implements RoutingManager +{ + private String adhocBackend; + private Map queryIdBackendMap = new HashMap<>(); + + public TestingRoutingManager(String adhocBackend) + { + this.adhocBackend = adhocBackend; + } + + public void setBackendForQueryId(String queryId, String backend) + { + queryIdBackendMap.put(queryId, backend); + } + + public String provideAdhocBackend(String user) + { + return adhocBackend; + } + + @Override + public String provideBackendForRoutingGroup(String routingGroup, String user) + { + return adhocBackend; + } + + @Override + public String findBackendForQueryId(String queryId) + { + return queryIdBackendMap.get(queryId); + } + + @Override + public void updateBackEndHealth(String backendId, Boolean value) + { + } + + @Override + public void updateBackEndStats(List stats) + { + } + + @Override + public String findBackendForUnknownQueryId(String queryId) + { + return adhocBackend; + } +}