Skip to content

Commit

Permalink
Process Request for routing rules
Browse files Browse the repository at this point in the history
  • Loading branch information
willmostly committed May 23, 2024
1 parent c052bb2 commit 8f2e510
Show file tree
Hide file tree
Showing 11 changed files with 1,136 additions and 15 deletions.
19 changes: 19 additions & 0 deletions gateway-ha/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@
</exclusions>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>aircompressor</artifactId>
<version>0.26</version>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>configuration</artifactId>
Expand Down Expand Up @@ -240,6 +246,12 @@
<artifactId>metrics-core</artifactId>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-parser</artifactId>
<version>440</version>
</dependency>

<dependency>
<groupId>jakarta.annotation</groupId>
<artifactId>jakarta.annotation-api</artifactId>
Expand Down Expand Up @@ -436,6 +448,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.gaul</groupId>
<artifactId>modernizer-maven-annotations</artifactId>
<version>2.7.0</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ public class HaGatewayConfiguration
private OAuth2GatewayCookieConfiguration oauth2GatewayCookieConfiguration = new OAuth2GatewayCookieConfiguration();
private GatewayCookieConfiguration gatewayCookieConfiguration = new GatewayCookieConfiguration();

private RequestAnalyzerConfig processedRequestConfig = new RequestAnalyzerConfig();

// List of Modules with FQCN (Fully Qualified Class Name)
private List<String> modules;

Expand Down Expand Up @@ -186,6 +188,16 @@ public void setGatewayCookieConfiguration(GatewayCookieConfiguration gatewayCook
this.gatewayCookieConfiguration = gatewayCookieConfiguration;
}

public RequestAnalyzerConfig getProcessedRequestConfig()
{
return processedRequestConfig;
}

public void setProcessedRequestConfig(RequestAnalyzerConfig requestAnalyzerConfig)
{
this.processedRequestConfig = requestAnalyzerConfig;
}

public List<String> getModules()
{
return this.modules;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.config;

public class RequestAnalyzerConfig
{
private Integer maxBodySize = 1_000_000;

private boolean isClientsUseV2Format;
private String tokenUserField = "email";
private String oauthTokenInfoUrl;
private boolean isGetUserFromOauth;

private boolean isAnalyzeRequest;

public RequestAnalyzerConfig() {}

public Integer getMaxBodySize()
{
return maxBodySize;
}

public void setMaxBodySize(Integer maxBodySize)
{
this.maxBodySize = maxBodySize;
}

public String getTokenUserField()
{
return tokenUserField;
}

public void setTokenUserField(String tokenUserField)
{
this.tokenUserField = tokenUserField;
}

public String getOauthTokenInfoUrl()
{
return oauthTokenInfoUrl;
//Google: "https://oauth2.googleapis.com/tokeninfo"
}

public void setOauthTokenInfoUrl(String oauthTokenInfoUrl)
{
this.oauthTokenInfoUrl = oauthTokenInfoUrl;
}

public boolean isGetUserFromOauth()
{
return isGetUserFromOauth;
}

public void setGetUserFromOauth(boolean getUserFromOauth)
{
isGetUserFromOauth = getUserFromOauth;
}

public boolean isClientsUseV2Format()
{
return isClientsUseV2Format;
}

public void setClientsUseV2Format(boolean clientsUseV2Format)
{
isClientsUseV2Format = clientsUseV2Format;
}

public boolean isAnalyzeRequest()
{
return isAnalyzeRequest;
}

public void setAnalyzeRequest(boolean analyzeRequest)
{
isAnalyzeRequest = analyzeRequest;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.HaGatewayConfiguration;
import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
import io.trino.gateway.ha.config.RequestRouterConfiguration;
import io.trino.gateway.ha.config.RoutingRulesConfiguration;
import io.trino.gateway.ha.config.UserConfiguration;
Expand Down Expand Up @@ -149,7 +150,7 @@ private ChainedAuthFilter getAuthenticationFilters(AuthenticationConfiguration c
}

private ProxyHandler getProxyHandler(QueryHistoryManager queryHistoryManager,
RoutingManager routingManager)
RoutingManager routingManager)
{
ProxyHandlerStats proxyHandlerStats = ProxyHandlerStats.create(
environment,
Expand All @@ -159,9 +160,10 @@ private ProxyHandler getProxyHandler(QueryHistoryManager queryHistoryManager,
RoutingGroupSelector routingGroupSelector = RoutingGroupSelector.byRoutingGroupHeader();
// Use rules engine if enabled
RoutingRulesConfiguration routingRulesConfig = configuration.getRoutingRules();
RequestAnalyzerConfig requestAnalyzerConfig = configuration.getProcessedRequestConfig();
if (routingRulesConfig.isRulesEngineEnabled()) {
String rulesConfigPath = routingRulesConfig.getRulesConfigPath();
routingGroupSelector = RoutingGroupSelector.byRoutingRulesEngine(rulesConfigPath);
routingGroupSelector = RoutingGroupSelector.byRoutingRulesEngine(rulesConfigPath, requestAnalyzerConfig);
}

return new QueryIdCachingProxyHandler(
Expand Down Expand Up @@ -211,7 +213,7 @@ private AuthFilter getAuthFilter(HaGatewayConfiguration configuration)
@Provides
@Singleton
public ProxyServer provideGateway(QueryHistoryManager queryHistoryManager,
RoutingManager routingManager)
RoutingManager routingManager)
{
ProxyServer gateway = null;
if (configuration.getRequestRouter() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.gateway.ha.router;

import io.trino.gateway.ha.config.RequestAnalyzerConfig;
import jakarta.servlet.http.HttpServletRequest;

/**
Expand All @@ -35,9 +36,9 @@ static RoutingGroupSelector byRoutingGroupHeader()
* Routing group selector that uses routing engine rules
* to determine the right routing group.
*/
static RoutingGroupSelector byRoutingRulesEngine(String rulesConfigPath)
static RoutingGroupSelector byRoutingRulesEngine(String rulesConfigPath, RequestAnalyzerConfig requestAnalyzerConfig)
{
return new RuleReloadingRoutingGroupSelector(rulesConfigPath);
return new RuleReloadingRoutingGroupSelector(rulesConfigPath, requestAnalyzerConfig);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.gateway.ha.router;

import io.airlift.log.Logger;
import io.trino.gateway.ha.config.RequestAnalyzerConfig;
import jakarta.servlet.http.HttpServletRequest;
import org.jeasy.rules.api.Facts;
import org.jeasy.rules.api.Rules;
Expand Down Expand Up @@ -43,10 +44,12 @@ public class RuleReloadingRoutingGroupSelector
private volatile Rules rules = new Rules();
private volatile long lastUpdatedTime;
private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true);
private final RequestAnalyzerConfig requestAnalyzerConfig;

RuleReloadingRoutingGroupSelector(String rulesConfigPath)
RuleReloadingRoutingGroupSelector(String rulesConfigPath, RequestAnalyzerConfig requestAnalyzerConfig)
{
this.rulesConfigPath = rulesConfigPath;
this.requestAnalyzerConfig = requestAnalyzerConfig;
try {
rules = ruleFactory.createRules(
new FileReader(rulesConfigPath, UTF_8));
Expand Down Expand Up @@ -84,9 +87,17 @@ public String findRoutingGroup(HttpServletRequest request)
writeLock.unlock();
}
}

Facts facts = new Facts();
HashMap<String, String> result = new HashMap<String, String>();

facts.put("request", request);
if (requestAnalyzerConfig.isAnalyzeRequest()) {
TrinoQueryProperties trinoQueryProperties = new TrinoQueryProperties(request, requestAnalyzerConfig);
TrinoRequestUser trinoRequestUser = new TrinoRequestUser(request, requestAnalyzerConfig);
facts.put("trinoQueryProperties", trinoQueryProperties);
facts.put("trinoRequestUser", trinoRequestUser);
}
facts.put("result", result);
Lock readLock = readWriteLock.readLock();
readLock.lock();
Expand Down
Loading

0 comments on commit 8f2e510

Please sign in to comment.