Skip to content

Commit

Permalink
Add option for custom v1/statement-like paths
Browse files Browse the repository at this point in the history
  • Loading branch information
willmostly committed Apr 26, 2024
1 parent a6ea834 commit 2d93a73
Show file tree
Hide file tree
Showing 12 changed files with 625 additions and 242 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class HaGatewayConfiguration
private List<String> extraWhitelistPaths = new ArrayList<>();
private OAuth2GatewayCookieConfiguration oauth2GatewayCookieConfiguration = new OAuth2GatewayCookieConfiguration();
private GatewayCookieConfiguration gatewayCookieConfiguration = new GatewayCookieConfiguration();
private List<String> customStatementPaths = new ArrayList<>();

// List of Modules with FQCN (Fully Qualified Class Name)
private List<String> modules;
Expand Down Expand Up @@ -205,4 +206,14 @@ public void setManagedApps(List<String> managedApps)
{
this.managedApps = managedApps;
}

public List<String> getCustomStatementPaths()
{
return customStatementPaths;
}

public void setCustomStatementPaths(List<String> customStatementPaths)
{
this.customStatementPaths = customStatementPaths;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.handler;

import com.codahale.metrics.Meter;
import io.airlift.stats.CounterStat;
import io.dropwizard.core.setup.Environment;
import org.weakref.jmx.MBeanExporter;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

import java.lang.management.ManagementFactory;

import static java.util.Objects.requireNonNull;

public final class DropWizardProxyHandlerStats
implements ProxyHandlerStats
{
// Airlift
private final CounterStat requestCount = new CounterStat();

// Dropwizard
private final Meter requestMeter;

private DropWizardProxyHandlerStats(Environment environment, String metricsName)
{
this.requestMeter = requireNonNull(environment, "environment is null")
.metrics()
.meter(requireNonNull(metricsName, "metricsName is null"));
}

@Override
public void recordRequest()
{
requestCount.update(1);
requestMeter.mark();
}

// Replace this with Guice bind after migrated to Airlift
public static ProxyHandlerStats create(Environment environment, String metricsName)
{
ProxyHandlerStats proxyHandlerStats = new DropWizardProxyHandlerStats(environment, metricsName);
MBeanExporter exporter = new MBeanExporter(ManagementFactory.getPlatformMBeanServer());
exporter.exportWithGeneratedName(proxyHandlerStats);
return proxyHandlerStats;
}

@Override
@Managed
@Nested
public CounterStat getRequestCount()
{
return requestCount;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,51 +13,15 @@
*/
package io.trino.gateway.ha.handler;

import com.codahale.metrics.Meter;
import io.airlift.stats.CounterStat;
import io.dropwizard.core.setup.Environment;
import org.weakref.jmx.MBeanExporter;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

import java.lang.management.ManagementFactory;

import static java.util.Objects.requireNonNull;

public final class ProxyHandlerStats
public interface ProxyHandlerStats
{
// Airlift
private final CounterStat requestCount = new CounterStat();

// Dropwizard
private final Meter requestMeter;

private ProxyHandlerStats(Environment environment, String metricsName)
{
this.requestMeter = requireNonNull(environment, "environment is null")
.metrics()
.meter(requireNonNull(metricsName, "metricsName is null"));
}

public void recordRequest()
{
requestCount.update(1);
requestMeter.mark();
}

// Replace this with Guice bind after migrated to Airlift
public static ProxyHandlerStats create(Environment environment, String metricsName)
{
ProxyHandlerStats proxyHandlerStats = new ProxyHandlerStats(environment, metricsName);
MBeanExporter exporter = new MBeanExporter(ManagementFactory.getPlatformMBeanServer());
exporter.exportWithGeneratedName(proxyHandlerStats);
return proxyHandlerStats;
}
void recordRequest();

@Managed
@Nested
public CounterStat getRequestCount()
{
return requestCount;
}
CounterStat getRequestCount();
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

import static com.google.common.base.Strings.isNullOrEmpty;

Expand Down Expand Up @@ -90,6 +91,7 @@ public class QueryIdCachingProxyHandler

private final ProxyHandlerStats proxyHandlerStats;
private final List<String> extraWhitelistPaths;
private final List<String> customStatementPaths;
private final String applicationEndpoint;
private final boolean cookiesEnabled;

Expand All @@ -99,36 +101,41 @@ public QueryIdCachingProxyHandler(
RoutingGroupSelector routingGroupSelector,
int serverApplicationPort,
ProxyHandlerStats proxyHandlerStats,
List<String> extraWhitelistPaths)
List<String> extraWhitelistPaths,
List<String> customStatementPaths)
{
this.proxyHandlerStats = proxyHandlerStats;
this.routingManager = routingManager;
this.routingGroupSelector = routingGroupSelector;
this.queryHistoryManager = queryHistoryManager;
this.extraWhitelistPaths = extraWhitelistPaths;
this.customStatementPaths = customStatementPaths;
this.applicationEndpoint = "http://localhost:" + serverApplicationPort;
cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
}

protected static String extractQueryIdIfPresent(String path, String queryParams)
protected String extractQueryIdIfPresent(String path, String queryParams)
{
if (path == null) {
return null;
}
String queryId = null;

log.debug("trying to extract query id from path [%s] or queryString [%s]", path, queryParams);
if (path.startsWith(V1_STATEMENT_PATH) || path.startsWith(V1_QUERY_PATH)) {
String[] tokens = path.split("/");
if (tokens.length >= 4) {
Optional<String> pathPrefix = Stream.concat(Stream.of(V1_STATEMENT_PATH, V1_QUERY_PATH), customStatementPaths.stream()).filter(path::startsWith).findFirst();
if (pathPrefix.isPresent()) {
String reducedPath = path.replace(pathPrefix.orElseThrow(), "");

String[] tokens = reducedPath.split("/");
if (tokens.length >= 2) {
if (path.contains("queued")
|| path.contains("scheduled")
|| path.contains("executing")
|| path.contains("partialCancel")) {
queryId = tokens[4];
queryId = tokens[2];
}
else {
queryId = tokens[3];
queryId = tokens[1];
}
}
}
Expand Down Expand Up @@ -246,7 +253,8 @@ protected String extractQueryIdIfPresent(HttpServletRequest request)
public void preConnectionHook(HttpServletRequest request, Request proxyRequest)
{
if (request.getMethod().equals(HttpMethod.POST)
&& request.getRequestURI().startsWith(V1_STATEMENT_PATH)) {
&& (request.getRequestURI().startsWith(V1_STATEMENT_PATH)
|| customStatementPaths.stream().anyMatch(request.getRequestURI()::startsWith))) {
proxyHandlerStats.recordRequest();
try {
String requestBody = CharStreams.toString(request.getReader());
Expand Down Expand Up @@ -275,7 +283,8 @@ private boolean isPathWhiteListed(String path)
|| path.startsWith(V1_NODE_PATH)
|| path.startsWith(UI_API_STATS_PATH)
|| path.startsWith(OAUTH_PATH)
|| extraWhitelistPaths.stream().anyMatch(s -> path.startsWith(s));
|| extraWhitelistPaths.stream().anyMatch(path::startsWith)
|| customStatementPaths.stream().anyMatch(path::startsWith);
}

@Override
Expand Down Expand Up @@ -378,10 +387,11 @@ protected void postConnectionHook(
Callback callback)
{
try {
if (request.getRequestURI().startsWith(V1_STATEMENT_PATH) && request.getMethod().equals(HttpMethod.POST)) {
String requestPath = request.getRequestURI();
if ((requestPath.startsWith(V1_STATEMENT_PATH) || customStatementPaths.stream().anyMatch(requestPath::startsWith)) && request.getMethod().equals(HttpMethod.POST)) {
recordBackendForQueryId(request, response, buffer);
}
else if (cookiesEnabled && request.getRequestURI().startsWith(OAuth2GatewayCookie.OAUTH2_PATH)
else if (cookiesEnabled && requestPath.startsWith(OAuth2GatewayCookie.OAUTH2_PATH)
&& !(request.getCookies() != null
&& Arrays.stream(request.getCookies()).anyMatch(c -> c.getName().equals(OAuth2GatewayCookie.NAME)))) {
GatewayCookie oauth2Cookie = new OAuth2GatewayCookie(request.getHeader(PROXY_TARGET_HEADER));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.trino.gateway.ha.config.RequestRouterConfiguration;
import io.trino.gateway.ha.config.RoutingRulesConfiguration;
import io.trino.gateway.ha.config.UserConfiguration;
import io.trino.gateway.ha.handler.DropWizardProxyHandlerStats;
import io.trino.gateway.ha.handler.ProxyHandlerStats;
import io.trino.gateway.ha.handler.QueryIdCachingProxyHandler;
import io.trino.gateway.ha.router.BackendStateManager;
Expand Down Expand Up @@ -72,6 +73,7 @@ public class HaGatewayProviderModule
private final BackendStateManager backendStateConnectionManager;
private final AuthFilter authenticationFilter;
private final List<String> extraWhitelistPaths;
private final List<String> customStatementPaths;
private final HaGatewayConfiguration configuration;
private final Environment environment;

Expand All @@ -94,6 +96,7 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration, Environment

OAuth2GatewayCookieConfigurationPropertiesProvider oAuth2GatewayCookieConfigurationPropertiesProvider = OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance();
oAuth2GatewayCookieConfigurationPropertiesProvider.initialize(configuration.getOauth2GatewayCookieConfiguration());
customStatementPaths = configuration.getCustomStatementPaths();
}

private LbOAuthManager getOAuthManager(HaGatewayConfiguration configuration)
Expand Down Expand Up @@ -149,9 +152,9 @@ private ChainedAuthFilter getAuthenticationFilters(AuthenticationConfiguration c
}

private ProxyHandler getProxyHandler(QueryHistoryManager queryHistoryManager,
RoutingManager routingManager)
RoutingManager routingManager)
{
ProxyHandlerStats proxyHandlerStats = ProxyHandlerStats.create(
ProxyHandlerStats proxyHandlerStats = DropWizardProxyHandlerStats.create(
environment,
configuration.getRequestRouter().getName() + ".requests");

Expand All @@ -170,7 +173,8 @@ private ProxyHandler getProxyHandler(QueryHistoryManager queryHistoryManager,
routingGroupSelector,
getApplicationPort(),
proxyHandlerStats,
extraWhitelistPaths);
extraWhitelistPaths,
customStatementPaths);
}

private int getApplicationPort()
Expand Down Expand Up @@ -211,7 +215,7 @@ private AuthFilter getAuthFilter(HaGatewayConfiguration configuration)
@Provides
@Singleton
public ProxyServer provideGateway(QueryHistoryManager queryHistoryManager,
RoutingManager routingManager)
RoutingManager routingManager)
{
ProxyServer gateway = null;
if (configuration.getRequestRouter() != null) {
Expand Down

0 comments on commit 2d93a73

Please sign in to comment.