Skip to content

Commit 91e4b51

Browse files
committed
[fix]Fix unit test cases.
1 parent bf3213e commit 91e4b51

File tree

22 files changed

+197
-115
lines changed

22 files changed

+197
-115
lines changed

chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ public ParseContext(ChatParseReq request, ChatParseResp response) {
1919
}
2020

2121
public boolean enableNL2SQL() {
22-
return Objects.nonNull(agent) && agent.containsDatasetTool()&&response.getSelectedParses().size() == 0;
22+
return Objects.nonNull(agent) && agent.containsDatasetTool()
23+
&& response.getSelectedParses().size() == 0;
2324
}
2425

2526
public boolean enableLLM() {

common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ public static String replaceSqlByExpression(String tableName, String sql,
727727
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelects(plainSelectList);
728728
for (PlainSelect plainSelect : plainSelects) {
729729
if (Objects.nonNull(plainSelect.getFromItem())) {
730-
Table table = (Table) plainSelect.getFromItem();
730+
Table table = SqlSelectHelper.getTable(plainSelect.getFromItem());
731731
if (table.getName().equals(tableName)) {
732732
replacePlainSelectByExpr(plainSelect, replace);
733733
if (SqlSelectHelper.hasAggregateFunction(plainSelect)) {

common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,44 @@ public static Table getTable(String sql) {
723723
return null;
724724
}
725725

726+
public static Table getTable(FromItem fromItem) {
727+
Table table = null;
728+
if (fromItem instanceof Table) {
729+
table = (Table) fromItem;
730+
} else if (fromItem instanceof ParenthesedSelect) {
731+
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) fromItem;
732+
if (parenthesedSelect.getSelect() instanceof PlainSelect) {
733+
PlainSelect subSelect = (PlainSelect) parenthesedSelect.getSelect();
734+
table = getTable(subSelect.getSelectBody());
735+
} else if (parenthesedSelect.getSelect() instanceof SetOperationList) {
736+
table = getTable(parenthesedSelect.getSelect());
737+
}
738+
}
739+
return table;
740+
}
741+
742+
public static Table getTable(Select select) {
743+
if (select == null) {
744+
return null;
745+
}
746+
List<PlainSelect> plainSelectList = getWithItem(select);
747+
if (!CollectionUtils.isEmpty(plainSelectList)) {
748+
List<PlainSelect> selectList = new ArrayList<>(plainSelectList);
749+
Table table = getTable(selectList.get(0));
750+
return table;
751+
}
752+
if (select instanceof PlainSelect) {
753+
PlainSelect plainSelect = (PlainSelect) select;
754+
return getTable(plainSelect.getFromItem());
755+
} else if (select instanceof SetOperationList) {
756+
SetOperationList setOperationList = (SetOperationList) select;
757+
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
758+
return getTable(setOperationList.getSelects().get(0));
759+
}
760+
}
761+
return null;
762+
}
763+
726764
public static String getDbTableName(String sql) {
727765
Table table = getTable(sql);
728766
return table.getFullyQualifiedName();

headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import com.google.common.collect.Lists;
44
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
55
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
6-
import com.tencent.supersonic.common.pojo.*;
6+
import com.tencent.supersonic.common.pojo.Aggregator;
7+
import com.tencent.supersonic.common.pojo.Constants;
8+
import com.tencent.supersonic.common.pojo.DateConf;
9+
import com.tencent.supersonic.common.pojo.Filter;
10+
import com.tencent.supersonic.common.pojo.Order;
711
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
812
import com.tencent.supersonic.common.pojo.enums.QueryType;
913
import com.tencent.supersonic.common.util.ContextUtils;
@@ -21,14 +25,22 @@
2125
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
2226
import net.sf.jsqlparser.schema.Column;
2327
import net.sf.jsqlparser.schema.Table;
24-
import net.sf.jsqlparser.statement.select.*;
28+
import net.sf.jsqlparser.statement.select.GroupByElement;
29+
import net.sf.jsqlparser.statement.select.Limit;
30+
import net.sf.jsqlparser.statement.select.Offset;
31+
import net.sf.jsqlparser.statement.select.OrderByElement;
32+
import net.sf.jsqlparser.statement.select.ParenthesedSelect;
33+
import net.sf.jsqlparser.statement.select.PlainSelect;
34+
import net.sf.jsqlparser.statement.select.SelectItem;
2535
import org.apache.commons.codec.digest.DigestUtils;
2636
import org.apache.commons.lang3.StringUtils;
2737
import org.springframework.util.CollectionUtils;
2838

2939
import java.util.ArrayList;
40+
import java.util.HashSet;
3041
import java.util.List;
3142
import java.util.Objects;
43+
import java.util.Set;
3244
import java.util.stream.Collectors;
3345

3446
@Data
@@ -176,7 +188,7 @@ private String buildSql(QueryStructReq queryStructReq, boolean isBizName)
176188

177189
private List<SelectItem<?>> buildSelectItems(QueryStructReq queryStructReq) {
178190
List<SelectItem<?>> selectItems = new ArrayList<>();
179-
List<String> groups = queryStructReq.getGroups();
191+
Set<String> groups = new HashSet<>(queryStructReq.getGroups());
180192

181193
if (!CollectionUtils.isEmpty(groups)) {
182194
for (String group : groups) {
@@ -236,7 +248,7 @@ private List<OrderByElement> buildOrderByElements(QueryStructReq queryStructReq)
236248
}
237249

238250
private GroupByElement buildGroupByElement(QueryStructReq queryStructReq) {
239-
List<String> groups = queryStructReq.getGroups();
251+
Set<String> groups = new HashSet<>(queryStructReq.getGroups());
240252
if ((!CollectionUtils.isEmpty(groups) && !queryStructReq.getAggregators().isEmpty())
241253
|| !queryStructReq.getMetricFilters().isEmpty()) {
242254
GroupByElement groupByElement = new GroupByElement();

headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/SqlExecuteReq.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ public class SqlExecuteReq {
2323
private Integer limit = 1000;
2424

2525
public String getSql() {
26-
if(StringUtils.isNotBlank(sql)){
27-
sql=sql.replaceAll("^[\\n]+|[\\n]+$", "");
28-
sql=StringUtils.removeEnd(sql,";");
26+
if (StringUtils.isNotBlank(sql)) {
27+
sql = sql.replaceAll("^[\\n]+|[\\n]+$", "");
28+
sql = StringUtils.removeEnd(sql, ";");
2929
}
3030

3131
return String.format(LIMIT_WRAPPER, sql, limit);

headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/file/FileHandlerImpl.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public PageInfo<DictValueResp> queryDictValue(String fileName, DictValueReq dict
8686
}
8787

8888
private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
89-
DictValueReq dictValueReq) {
89+
DictValueReq dictValueReq) {
9090
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
9191
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
9292
dictValueRespPageInfo.setPageNum(dictValueReq.getCurrent());
@@ -95,7 +95,7 @@ private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
9595
Integer startLine = 1;
9696
List<DictValueResp> dictValueRespList =
9797
getFileData(filePath, startLine, fileLineNum.intValue()).stream().filter(
98-
dictValue -> dictValue.getValue().contains(dictValueReq.getKeyValue()))
98+
dictValue -> dictValue.getValue().contains(dictValueReq.getKeyValue()))
9999
.collect(Collectors.toList());
100100
if (CollectionUtils.isEmpty(dictValueRespList)) {
101101
dictValueRespPageInfo.setList(new ArrayList<>());
@@ -118,7 +118,7 @@ private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
118118
}
119119

120120
private PageInfo<DictValueResp> getDictValueRespPagWithoutKey(String fileName,
121-
DictValueReq dictValueReq) {
121+
DictValueReq dictValueReq) {
122122
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
123123
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
124124
Long fileLineNum = getFileLineNum(filePath);
@@ -175,7 +175,7 @@ private List<DictValueResp> getFileData(String filePath, Integer startLine, Inte
175175
private DictValueResp convert2Resp(String lineStr) {
176176
DictValueResp dictValueResp = new DictValueResp();
177177
if (StringUtils.isNotEmpty(lineStr)) {
178-
lineStr=StringUtils.stripStart(lineStr,null);
178+
lineStr = StringUtils.stripStart(lineStr, null);
179179
String[] itemArray = lineStr.split("\\s+");
180180
if (Objects.nonNull(itemArray) && itemArray.length >= 3) {
181181
dictValueResp.setValue(itemArray[0].replace("#", " "));

headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
6363

6464
@Override
6565
public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
66-
Set<Long> detectDataSetIds) {
66+
Set<Long> detectDataSetIds) {
6767
if (chatQueryContext == null || CollectionUtils.isEmpty(detectDataSetIds)) {
6868
log.warn("Invalid input parameters: context={}, dataSetIds={}", chatQueryContext,
6969
detectDataSetIds);
@@ -92,7 +92,7 @@ public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Te
9292
* Perform enhanced detection using LLM
9393
*/
9494
private List<EmbeddingResult> detectWithLLM(ChatQueryContext chatQueryContext,
95-
Set<Long> detectDataSetIds) {
95+
Set<Long> detectDataSetIds) {
9696
try {
9797
String queryText = chatQueryContext.getRequest().getQueryText();
9898
if (StringUtils.isBlank(queryText)) {
@@ -126,7 +126,7 @@ private Set<String> extractValidSegments(String text) {
126126

127127
@Override
128128
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
129-
Set<Long> detectDataSetIds, Set<String> detectSegments) {
129+
Set<Long> detectDataSetIds, Set<String> detectSegments) {
130130
return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments, false);
131131
}
132132

@@ -140,7 +140,7 @@ public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
140140
* @return List of embedding results
141141
*/
142142
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
143-
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
143+
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
144144
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
145145
int embeddingMapperBatch = Integer
146146
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
@@ -168,10 +168,11 @@ public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
168168
variable.put("retrievedInfo", JSONObject.toJSONString(results));
169169

170170
Prompt prompt = PromptTemplate.from(LLM_FILTER_PROMPT).apply(variable);
171-
ChatModelConfig chatModelConfig=null;
172-
if(chatQueryContext.getRequest().getChatAppConfig()!=null
173-
&& chatQueryContext.getRequest().getChatAppConfig().containsKey("REWRITE_MULTI_TURN")){
174-
chatModelConfig=chatQueryContext.getRequest().getChatAppConfig().get("REWRITE_MULTI_TURN").getChatModelConfig();
171+
ChatModelConfig chatModelConfig = null;
172+
if (chatQueryContext.getRequest().getChatAppConfig() != null && chatQueryContext
173+
.getRequest().getChatAppConfig().containsKey("REWRITE_MULTI_TURN")) {
174+
chatModelConfig = chatQueryContext.getRequest().getChatAppConfig()
175+
.get("REWRITE_MULTI_TURN").getChatModelConfig();
175176
}
176177
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatModelConfig);
177178
String response = chatLanguageModel.generate(prompt.toUserMessage().singleText());
@@ -200,7 +201,7 @@ public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
200201
* @return Callable task
201202
*/
202203
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
203-
List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
204+
List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
204205
return () -> {
205206
List<EmbeddingResult> oneRoundResults = detectByQueryTextsSub(detectDataSetIds,
206207
queryTextsSub, chatQueryContext, useLlm);
@@ -221,7 +222,7 @@ private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> d
221222
* @return List of embedding results for this batch
222223
*/
223224
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
224-
List<String> queryTextsSub, ChatQueryContext chatQueryContext, boolean useLlm) {
225+
List<String> queryTextsSub, ChatQueryContext chatQueryContext, boolean useLlm) {
225226
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
226227

227228
// Get configuration parameters
@@ -243,12 +244,12 @@ private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
243244

244245
// Process results
245246
List<EmbeddingResult> collect = retrieveQueryResults.stream().peek(result -> {
246-
if (!useLlm && CollectionUtils.isNotEmpty(result.getRetrieval())) {
247-
result.getRetrieval()
248-
.removeIf(retrieval -> !result.getQuery().contains(retrieval.getQuery())
249-
&& retrieval.getSimilarity() < threshold);
250-
}
251-
}).filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
247+
if (!useLlm && CollectionUtils.isNotEmpty(result.getRetrieval())) {
248+
result.getRetrieval()
249+
.removeIf(retrieval -> !result.getQuery().contains(retrieval.getQuery())
250+
&& retrieval.getSimilarity() < threshold);
251+
}
252+
}).filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
252253
.flatMap(result -> result.getRetrieval().stream()
253254
.map(retrieval -> convertToEmbeddingResult(result, retrieval)))
254255
.collect(Collectors.toList());
@@ -267,7 +268,7 @@ private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
267268
* @return Converted EmbeddingResult
268269
*/
269270
private EmbeddingResult convertToEmbeddingResult(RetrieveQueryResult queryResult,
270-
Retrieval retrieval) {
271+
Retrieval retrieval) {
271272
EmbeddingResult embeddingResult = new EmbeddingResult();
272273
BeanUtils.copyProperties(retrieval, embeddingResult);
273274
embeddingResult.setDetectWord(queryResult.getQuery());

headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public void doMap(ChatQueryContext chatQueryContext) {
5151
}
5252

5353
private void convertMapResultToMapInfo(List<HanlpMapResult> mapResults,
54-
ChatQueryContext chatQueryContext, List<S2Term> terms) {
54+
ChatQueryContext chatQueryContext, List<S2Term> terms) {
5555
if (CollectionUtils.isEmpty(mapResults)) {
5656
return;
5757
}
@@ -87,14 +87,15 @@ private void convertMapResultToMapInfo(List<HanlpMapResult> mapResults,
8787
.similarity(hanlpMapResult.getSimilarity())
8888
.detectWord(hanlpMapResult.getDetectWord()).build();
8989
// doDimValueAliasLogic 将维度值别名进行替换成真实维度值
90-
doDimValueAliasLogic(schemaElementMatch,chatQueryContext.getSemanticSchema().getDimensionValues());
90+
doDimValueAliasLogic(schemaElementMatch,
91+
chatQueryContext.getSemanticSchema().getDimensionValues());
9192
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
9293
}
9394
}
9495
}
9596

9697
private void doDimValueAliasLogic(SchemaElementMatch schemaElementMatch,
97-
List<SchemaElement> dimensionValues) {
98+
List<SchemaElement> dimensionValues) {
9899
SchemaElement element = schemaElementMatch.getElement();
99100
if (SchemaElementType.VALUE.equals(element.getType())) {
100101
Long dimId = element.getId();
@@ -126,7 +127,7 @@ private void doDimValueAliasLogic(SchemaElementMatch schemaElementMatch,
126127
}
127128

128129
private void convertMapResultToMapInfo(ChatQueryContext chatQueryContext,
129-
List<DatabaseMapResult> mapResults) {
130+
List<DatabaseMapResult> mapResults) {
130131
for (DatabaseMapResult match : mapResults) {
131132
SchemaElement schemaElement = match.getSchemaElement();
132133
Set<Long> regElementSet =
@@ -153,8 +154,8 @@ private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schema
153154
return new HashSet<>();
154155
}
155156
return elements.stream().filter(
156-
elementMatch -> SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
157-
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
157+
elementMatch -> SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
158+
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
158159
.map(elementMatch -> elementMatch.getElement().getId()).collect(Collectors.toSet());
159160
}
160161
}

headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/DimValueAspect.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public Object handleSqlDimValue(ProceedingJoinPoint joinPoint) throws Throwable
124124
sql = SqlReplaceHelper.replaceValue(sql, filedNameToValueMap);
125125
log.debug("correctorSql after replacing:{}", sql);
126126
querySqlReq.setSql(sql);
127-
querySqlReq.getSqlInfo().setQuerySQL(sql);
127+
// querySqlReq.getSqlInfo().setQuerySQL(sql);
128128
Map<String, Map<String, String>> techNameToBizName = getTechNameToBizName(dimensions);
129129

130130
SemanticQueryResp queryResultWithColumns = (SemanticQueryResp) joinPoint.proceed();

headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DimensionRepositoryImpl.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public List<DimensionDO> getDimension(DimensionFilter dimensionFilter) {
8888
}
8989
if (StringUtils.isNotBlank(dimensionFilter.getKey())) {
9090
String key = dimensionFilter.getKey();
91-
queryWrapper.and(qw->qw.lambda().like(DimensionDO::getName, key).or()
91+
queryWrapper.and(qw -> qw.lambda().like(DimensionDO::getName, key).or()
9292
.like(DimensionDO::getBizName, key).or().like(DimensionDO::getDescription, key)
9393
.or().like(DimensionDO::getAlias, key).or()
9494
.like(DimensionDO::getCreatedBy, key));

0 commit comments

Comments
 (0)