Skip to content

Commit 5e3bafb

Browse files
authored
feat:Support kyuubi presto trino (#2109)
1 parent 11ff99c commit 5e3bafb

File tree

31 files changed

+499
-99
lines changed

31 files changed

+499
-99
lines changed

common/src/main/java/com/tencent/supersonic/common/pojo/enums/EngineType.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ public enum EngineType {
1010
OTHER(7, "OTHER"),
1111
DUCKDB(8, "DUCKDB"),
1212
HANADB(9, "HANADB"),
13-
STARROCKS(10, "STARROCKS"),;
13+
STARROCKS(10, "STARROCKS"),
14+
KYUUBI(11, "KYUUBI"),
15+
PRESTO(12, "PRESTO"),
16+
TRINO(13, "TRINO"),;
1417

1518
private Integer code;
1619

headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DbSchema.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
@NoArgsConstructor
1212
public class DbSchema {
1313

14+
private String catalog;
15+
1416
private String db;
1517

1618
private String table;

headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DataType.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
public enum DataType {
99
MYSQL("mysql", "mysql", "com.mysql.cj.jdbc.Driver", "`", "`", "'", "'"),
1010

11-
HIVE2("hive2", "hive", "org.apache.hive.jdbc.HiveDriver", "`", "`", "`", "`"),
11+
HIVE2("hive2", "hive", "org.apache.kyuubi.jdbc.KyuubiHiveDriver", "`", "`", "`", "`"),
12+
13+
KYUUBI("kyuubi", "kyuubi", "org.apache.kyuubi.jdbc.KyuubiHiveDriver", "`", "`", "`", "`"),
1214

1315
ORACLE("oracle", "oracle", "oracle.jdbc.driver.OracleDriver", "\"", "\"", "\"", "\""),
1416

@@ -27,6 +29,8 @@ public enum DataType {
2729

2830
PRESTO("presto", "presto", "com.facebook.presto.jdbc.PrestoDriver", "\"", "\"", "\"", "\""),
2931

32+
TRINO("trino", "trino", "io.trino.jdbc.TrinoDriver", "\"", "\"", "\"", "\""),
33+
3034
MOONBOX("moonbox", "moonbox", "moonbox.jdbc.MbDriver", "`", "`", "`", "`"),
3135

3236
CASSANDRA("cassandra", "cassandra", "com.github.adejanovski.cassandra.jdbc.CassandraDriver", "",
@@ -46,6 +50,7 @@ public enum DataType {
4650
TDENGINE("TAOS", "TAOS", "com.taosdata.jdbc.TSDBDriver", "'", "'", "\"", "\""),
4751

4852
POSTGRESQL("postgresql", "postgresql", "org.postgresql.Driver", "'", "'", "\"", "\""),
53+
4954
DUCKDB("duckdb", "duckdb", "org.duckdb.DuckDBDriver", "'", "'", "\"", "\"");
5055

5156
private String feature;

headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ModelBuildReq.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public class ModelBuildReq {
1919

2020
private String sql;
2121

22+
private String catalog;
23+
2224
private String db;
2325

2426
private List<String> tables;

headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Te
7272
// 1. Base detection
7373
List<EmbeddingResult> baseResults = super.detect(chatQueryContext, terms, detectDataSetIds);
7474

75-
boolean useLLM = Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM));
75+
boolean useLLM =
76+
Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM));
7677

7778
// 2. LLM enhanced detection
7879
if (useLLM) {
@@ -115,7 +116,8 @@ private List<EmbeddingResult> detectWithLLM(ChatQueryContext chatQueryContext,
115116
* Extract valid word segments by filtering out unwanted word natures
116117
*/
117118
private Set<String> extractValidSegments(String text) {
118-
List<String> natureList = Arrays.asList(StringUtils.split(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE ), ","));
119+
List<String> natureList = Arrays.asList(StringUtils.split(
120+
mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE), ","));
119121
return HanlpHelper.getSegment().seg(text).stream()
120122
.filter(t -> natureList.stream().noneMatch(nature -> t.nature.startsWith(nature)))
121123
.map(Term::getWord).collect(Collectors.toSet());

headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ public static void filterByDetectWordLenLessThanOne(ChatQueryContext chatQueryCo
6161
List<SchemaElementMatch> value = entry.getValue();
6262
if (!CollectionUtils.isEmpty(value)) {
6363
value.removeIf(schemaElementMatch -> StringUtils
64-
.length(schemaElementMatch.getDetectWord()) <= 1 && !schemaElementMatch.isLlmMatched());
64+
.length(schemaElementMatch.getDetectWord()) <= 1
65+
&& !schemaElementMatch.isLlmMatched());
6566
}
6667
}
6768
}
@@ -80,7 +81,7 @@ private static void twoCharactersMustEqual(ChatQueryContext chatQueryContext) {
8081
}
8182

8283
public static void filterByQueryDataType(ChatQueryContext chatQueryContext,
83-
Predicate<SchemaElement> needRemovePredicate) {
84+
Predicate<SchemaElement> needRemovePredicate) {
8485
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
8586
chatQueryContext.getMapInfo().getDataSetElementMatches();
8687
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {

headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,6 @@ public class MapperConfig extends ParameterConfig {
6363
"embedding的结果再通过一次LLM来筛选,这时候忽略各个向量阀值", "bool", "Mapper相关配置");
6464

6565
public static final Parameter EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE =
66-
new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']", "使用LLM召回二次处理时对问题分词词性的控制",
67-
"分词后允许的词性才会进行向量召回", "list", "Mapper相关配置");
66+
new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']",
67+
"使用LLM召回二次处理时对问题分词词性的控制", "分词后允许的词性才会进行向量召回", "list", "Mapper相关配置");
6868
}

headless/core/pom.xml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,18 @@
121121
<artifactId>DmJdbcDriver18</artifactId>
122122
<version>8.1.2.192</version>
123123
</dependency>
124+
<dependency>
125+
<groupId>org.apache.kyuubi</groupId>
126+
<artifactId>kyuubi-hive-jdbc</artifactId>
127+
</dependency>
128+
<dependency>
129+
<groupId>com.facebook.presto</groupId>
130+
<artifactId>presto-jdbc</artifactId>
131+
</dependency>
132+
<dependency>
133+
<groupId>io.trino</groupId>
134+
<artifactId>trino-jdbc</artifactId>
135+
</dependency>
124136
</dependencies>
125137

126138

headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java

Lines changed: 79 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,27 @@
55
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
66
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
77
import lombok.extern.slf4j.Slf4j;
8+
import org.apache.commons.lang3.StringUtils;
89

9-
import java.sql.Connection;
10-
import java.sql.DatabaseMetaData;
11-
import java.sql.DriverManager;
12-
import java.sql.ResultSet;
13-
import java.sql.SQLException;
10+
import java.sql.*;
1411
import java.util.ArrayList;
1512
import java.util.List;
13+
import java.util.Properties;
1614

1715
@Slf4j
1816
public abstract class BaseDbAdaptor implements DbAdaptor {
1917

2018
@Override
2119
public List<String> getCatalogs(ConnectInfo connectInfo) throws SQLException {
22-
// Apart from supporting multiple catalog types of data sources, other types will return an
23-
// empty set by default.
24-
return List.of();
20+
List<String> catalogs = Lists.newArrayList();
21+
try (Connection con = getConnection(connectInfo);
22+
Statement st = con.createStatement();
23+
ResultSet rs = st.executeQuery("SHOW CATALOGS")) {
24+
while (rs.next()) {
25+
catalogs.add(rs.getString(1));
26+
}
27+
}
28+
return catalogs;
2529
}
2630

2731
public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQLException {
@@ -32,38 +36,49 @@ public List<String> getDBs(ConnectInfo connectionInfo, String catalog) throws SQ
3236

3337
protected List<String> getDBs(ConnectInfo connectionInfo) throws SQLException {
3438
List<String> dbs = Lists.newArrayList();
35-
DatabaseMetaData metaData = getDatabaseMetaData(connectionInfo);
3639
try {
37-
ResultSet schemaSet = metaData.getSchemas();
38-
while (schemaSet.next()) {
39-
String db = schemaSet.getString("TABLE_SCHEM");
40-
dbs.add(db);
40+
try (ResultSet schemaSet = getDatabaseMetaData(connectionInfo).getSchemas()) {
41+
while (schemaSet.next()) {
42+
String db = schemaSet.getString("TABLE_SCHEM");
43+
dbs.add(db);
44+
}
4145
}
4246
} catch (Exception e) {
43-
log.info("get meta schemas failed, try to get catalogs");
47+
log.warn("get meta schemas failed", e);
48+
log.warn("get meta schemas failed, try to get catalogs");
4449
}
4550
try {
46-
ResultSet catalogSet = metaData.getCatalogs();
47-
while (catalogSet.next()) {
48-
String db = catalogSet.getString("TABLE_CAT");
49-
dbs.add(db);
51+
try (ResultSet catalogSet = getDatabaseMetaData(connectionInfo).getCatalogs()) {
52+
while (catalogSet.next()) {
53+
String db = catalogSet.getString("TABLE_CAT");
54+
dbs.add(db);
55+
}
5056
}
5157
} catch (Exception e) {
52-
log.info("get meta catalogs failed, try to get schemas");
58+
log.warn("get meta catalogs failed", e);
59+
log.warn("get meta catalogs failed, try to get schemas");
5360
}
5461
return dbs;
5562
}
5663

57-
public List<String> getTables(ConnectInfo connectionInfo, String schemaName)
64+
@Override
65+
public List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
66+
throws SQLException {
67+
// Except for special types implemented separately, the generic logic catalog does not take
68+
// effect.
69+
return getTables(connectInfo, schemaName);
70+
}
71+
72+
protected List<String> getTables(ConnectInfo connectionInfo, String schemaName)
5873
throws SQLException {
5974
List<String> tablesAndViews = new ArrayList<>();
60-
DatabaseMetaData metaData = getDatabaseMetaData(connectionInfo);
6175

6276
try {
63-
ResultSet resultSet = getResultSet(schemaName, metaData);
64-
while (resultSet.next()) {
65-
String name = resultSet.getString("TABLE_NAME");
66-
tablesAndViews.add(name);
77+
try(ResultSet resultSet = getResultSet(schemaName, getDatabaseMetaData(connectionInfo))) {
78+
while (resultSet.next()) {
79+
String name = resultSet.getString("TABLE_NAME");
80+
tablesAndViews.add(name);
81+
}
6782
}
6883
} catch (SQLException e) {
6984
log.error("Failed to get tables and views", e);
@@ -76,27 +91,34 @@ protected ResultSet getResultSet(String schemaName, DatabaseMetaData metaData)
7691
return metaData.getTables(schemaName, schemaName, null, new String[] {"TABLE", "VIEW"});
7792
}
7893

79-
public List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
94+
95+
96+
public List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
8097
throws SQLException {
81-
List<DBColumn> dbColumns = Lists.newArrayList();
82-
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo);
83-
ResultSet columns = metaData.getColumns(schemaName, schemaName, tableName, null);
84-
while (columns.next()) {
85-
String columnName = columns.getString("COLUMN_NAME");
86-
String dataType = columns.getString("TYPE_NAME");
87-
String remarks = columns.getString("REMARKS");
88-
FieldType fieldType = classifyColumnType(dataType);
89-
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
98+
List<DBColumn> dbColumns = new ArrayList<>();
99+
// 确保连接会自动关闭
100+
try (ResultSet columns = getDatabaseMetaData(connectInfo).getColumns(catalog, schemaName, tableName, null)) {
101+
while (columns.next()) {
102+
String columnName = columns.getString("COLUMN_NAME");
103+
String dataType = columns.getString("TYPE_NAME");
104+
String remarks = columns.getString("REMARKS");
105+
FieldType fieldType = classifyColumnType(dataType);
106+
dbColumns.add(new DBColumn(columnName, dataType, remarks, fieldType));
107+
}
90108
}
91109
return dbColumns;
92110
}
93111

94112
protected DatabaseMetaData getDatabaseMetaData(ConnectInfo connectionInfo) throws SQLException {
95-
Connection connection = DriverManager.getConnection(connectionInfo.getUrl(),
96-
connectionInfo.getUserName(), connectionInfo.getPassword());
113+
Connection connection = getConnection(connectionInfo);
97114
return connection.getMetaData();
98115
}
99116

117+
public Connection getConnection(ConnectInfo connectionInfo) throws SQLException {
118+
final Properties properties = getProperties(connectionInfo);
119+
return DriverManager.getConnection(connectionInfo.getUrl(), properties);
120+
}
121+
100122
public FieldType classifyColumnType(String typeName) {
101123
switch (typeName.toUpperCase()) {
102124
case "INT":
@@ -118,4 +140,24 @@ public FieldType classifyColumnType(String typeName) {
118140
}
119141
}
120142

143+
public Properties getProperties(ConnectInfo connectionInfo) {
144+
final Properties properties = new Properties();
145+
String url = connectionInfo.getUrl().toLowerCase();
146+
147+
// 设置通用属性
148+
properties.setProperty("user", connectionInfo.getUserName());
149+
150+
// 针对 Presto 和 Trino ssl=false 的情况,不需要设置密码
151+
if (url.startsWith("jdbc:presto") || url.startsWith("jdbc:trino")) {
152+
// 检查是否需要处理 SSL
153+
if (!url.contains("ssl=false")) {
154+
properties.setProperty("password", connectionInfo.getPassword());
155+
}
156+
} else {
157+
// 针对其他数据库类型
158+
properties.setProperty("password", connectionInfo.getPassword());
159+
}
160+
161+
return properties;
162+
}
121163
}

headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/DbAdaptor.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ public interface DbAdaptor {
1818

1919
List<String> getDBs(ConnectInfo connectInfo, String catalog) throws SQLException;
2020

21-
List<String> getTables(ConnectInfo connectInfo, String schemaName) throws SQLException;
21+
List<String> getTables(ConnectInfo connectInfo, String catalog, String schemaName)
22+
throws SQLException;
2223

23-
List<DBColumn> getColumns(ConnectInfo connectInfo, String schemaName, String tableName)
24+
List<DBColumn> getColumns(ConnectInfo connectInfo, String catalog, String schemaName, String tableName)
2425
throws SQLException;
2526

2627
FieldType classifyColumnType(String typeName);

0 commit comments

Comments
 (0)