Skip to content

Commit 7e6639d

Browse files
committed
(improvement)(common|headless|chat|auth) 鉴权优化与召回优化
1. 修复生成的用户 token 一生成就失效的问题;2. 对于用户设置的 token,需校验其在数据库中是否存在,因为用户可设置有效期长达一年的 token,存在泄露风险;3. 结果解析优化:去除不可以解析的情况,解析问题时需要使用改写后的问题;4. 召回样例时使用相似度,保证至少有一个样例是高相似度的;5. 数据集召回:添加完全匹配格式筛选逻辑
1 parent 0721df2 commit 7e6639d

File tree

8 files changed

+84
-18
lines changed

8 files changed

+84
-18
lines changed

auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,9 @@ public UserToken generateToken(String name, String userName, long expireTime) {
222222
new UserWithPassword(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
223223
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
224224

225+
// 使用令牌名称作为生成key ,这样可以区分正常请求和api 请求,api 的令牌失效时间很长,需考虑令牌泄露的情况
225226
String token =
226-
tokenService.generateToken(UserWithPassword.convert(userWithPassword), expireTime);
227+
tokenService.generateToken(UserWithPassword.convert(userWithPassword),"SysDbToken:"+name, (new Date().getTime() + expireTime));
227228
UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime);
228229
return convertUserToken(userTokenDO);
229230
}

auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/TokenService.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66

77
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
88
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
9+
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserTokenDO;
10+
import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository;
911
import com.tencent.supersonic.common.pojo.exception.AccessException;
12+
import com.tencent.supersonic.common.util.ContextUtils;
1013
import io.jsonwebtoken.Claims;
1114
import io.jsonwebtoken.Jwts;
1215
import io.jsonwebtoken.SignatureAlgorithm;
@@ -71,6 +74,7 @@ public String generateAppUserToken(HttpServletRequest request) {
7174
return generateToken(UserWithPassword.convert(appUser), request);
7275
}
7376

77+
7478
public Optional<Claims> getClaims(HttpServletRequest request) {
7579
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
7680
String appKey = getAppKey(request);
@@ -90,6 +94,13 @@ private Optional<Claims> getClaims(String token, HttpServletRequest request) {
9094

9195
public Optional<Claims> getClaims(String token, String appKey) {
9296
try {
97+
if(StringUtils.isNotBlank(appKey)&&appKey.startsWith("SysDbToken:")) {// 如果是配置的长期令牌,需校验数据库是否存在该配置
98+
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
99+
UserTokenDO dbToken= userRepository.getUserTokenByName(appKey.substring("SysDbToken:".length()));
100+
if(dbToken==null||!dbToken.getToken().equals(token.replace("Bearer ",""))) {
101+
throw new AccessException("Token does not exist :" + appKey);
102+
}
103+
}
93104
String tokenSecret = getTokenSecret(appKey);
94105
Claims claims =
95106
Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))
@@ -122,6 +133,16 @@ private String getTokenSecret(String appKey) {
122133
Map<String, String> appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap();
123134
String secret = appKeyToSecretMap.get(appKey);
124135
if (StringUtils.isBlank(secret)) {
136+
if(StringUtils.isNotBlank(appKey)&&appKey.startsWith("SysDbToken:")) { // 是配置的长期令牌
137+
String realAppKey=appKey.substring("SysDbToken:".length());
138+
String tmp = "WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==";
139+
if(tmp.length()<=realAppKey.length()) {
140+
return realAppKey;
141+
}
142+
else{
143+
return realAppKey+tmp.substring(realAppKey.length());
144+
}
145+
}
125146
throw new AccessException("get secret from appKey failed :" + appKey);
126147
}
127148
return secret;

chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DataInterpretProcessor.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ public boolean accept(ExecuteContext executeContext) {
4747
Agent agent = executeContext.getAgent();
4848
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
4949
return Objects.nonNull(chatApp) && chatApp.isEnable()
50-
&& StringUtils.isNotBlank(executeContext.getResponse().getTextResult()); // 如果都没结果,则无法处理,直接跳过
50+
&& StringUtils.isNotBlank(executeContext.getResponse().getTextResult()) // 如果都没结果,则无法处理
51+
&& StringUtils.isBlank(executeContext.getResponse().getTextSummary()); // 如果已经有汇总的结果了,无法再次处理
5152
}
5253

5354
@Override
@@ -57,7 +58,15 @@ public void process(ExecuteContext executeContext) {
5758
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
5859

5960
Map<String, Object> variable = new HashMap<>();
60-
variable.put("question", executeContext.getRequest().getQueryText());
61+
String question = executeContext.getResponse().getTextResult();// 结果解析应该用改写的问题,因为改写的内容信息量更大
62+
if(executeContext.getParseInfo().getProperties()!=null&&
63+
executeContext.getParseInfo().getProperties().containsKey("CONTEXT")){
64+
Map<String,Object> context = (Map<String, Object>) executeContext.getParseInfo().getProperties().get("CONTEXT");
65+
if(context.get("queryText")!=null&&"".equals(context.get("queryText"))){
66+
question = context.get("queryText").toString();
67+
}
68+
}
69+
variable.put("question", question);
6170
variable.put("data", queryResult.getTextResult());
6271

6372
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);

common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public List removeNatures(List value, Set<Long> modelIdOrDataSetIds) {
2121
List<String> resultList = new ArrayList<>(value);
2222
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
2323
resultList.removeIf(nature -> {
24-
if (Objects.isNull(nature)) {
24+
if (Objects.isNull(nature)||!nature.startsWith("_")) { // 系统的字典是以 _ 开头的, 过滤因引用外部字典导致的异常
2525
return false;
2626
}
2727
Long id = getId(nature);

common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@ public class Text2SQLExemplar implements Serializable {
2222
private String dbSchema;
2323

2424
private String sql;
25+
26+
protected double similarity; // 传递相似度,可以作为样本筛选的依据
2527
}

common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ public List<Text2SQLExemplar> recallExemplars(String collection, String query, i
7272
embeddingService.retrieveQuery(collection, retrieveQuery, num);
7373
results.forEach(ret -> {
7474
ret.getRetrieval().forEach(r -> {
75-
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
75+
Text2SQLExemplar tmp = //传递相似度,可以作为样本筛选的依据
76+
JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class);
77+
tmp.setSimilarity(r.getSimilarity());
78+
exemplars.add(tmp);
7679
});
7780
});
7881

headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.Map;
1919
import java.util.Objects;
2020
import java.util.Set;
21+
import java.util.stream.Collectors;
2122

2223
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_DETAIL_LIMIT;
2324
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
@@ -65,12 +66,23 @@ public int compare(SemanticParseInfo o1, SemanticParseInfo o2) {
6566
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
6667

6768
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
68-
if (difference == 0) {
69+
if (Math.abs(difference) < 0.0005) { // 看完全匹配的个数,实践证明,可以用户输入规范后,该逻辑具有优势
70+
if (!o1.getDataSetId().equals(o2.getDataSetId())) {
71+
List<SchemaElementMatch> elementMatches1 = o1.getElementMatches().stream()
72+
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
73+
List<SchemaElementMatch> elementMatches2 = o2.getElementMatches().stream()
74+
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
75+
if (elementMatches1.size() > elementMatches2.size()) {
76+
return -1;
77+
} else if (elementMatches1.size() < elementMatches2.size()) {
78+
return 1;
79+
}
80+
}
6981
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
70-
if (difference == 0) {
82+
if (Math.abs(difference) < 0.0005) {
7183
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
7284
}
73-
if (difference == 0) {
85+
if (Math.abs(difference) < 0.0005) {
7486
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
7587
}
7688
}

headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@
1414
import org.springframework.stereotype.Component;
1515
import org.springframework.util.CollectionUtils;
1616

17-
import java.util.ArrayList;
18-
import java.util.Collections;
19-
import java.util.List;
20-
import java.util.Objects;
17+
import java.util.*;
18+
import java.util.stream.Collectors;
2119

2220
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
2321

@@ -51,13 +49,33 @@ public List<List<Text2SQLExemplar>> getFewShotExemplars(LLMReq llmReq) {
5149
// use random collection of exemplars for each self-consistency inference
5250
for (int i = 0; i < selfConsistencyNumber; i++) {
5351
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
54-
// only shuffle the exemplars from config
55-
List<Text2SQLExemplar> subList =
56-
shuffledList.subList(llmReq.getDynamicExemplars().size(), shuffledList.size());
57-
Collections.shuffle(subList);
58-
results.add(shuffledList.subList(0, Math.min(shuffledList.size(), fewShotNumber)));
52+
List<Text2SQLExemplar> same = shuffledList.stream() // 相似度极高的话,先找出来
53+
.filter(e -> e.getSimilarity() > 0.989).collect(Collectors.toList());
54+
List<Text2SQLExemplar> noSame = shuffledList.stream()
55+
.filter(e -> e.getSimilarity() <= 0.989).collect(Collectors.toList());
56+
if ((noSame.size() - same.size()) > fewShotNumber) {// 去除部分最低分
57+
noSame.sort(Comparator.comparingDouble(Text2SQLExemplar::getSimilarity));
58+
noSame = noSame.subList((noSame.size() - fewShotNumber) / 2, noSame.size());
59+
}
60+
Text2SQLExemplar mostSimilar = noSame.get(noSame.size() - 1);
61+
Collections.shuffle(noSame);
62+
List<Text2SQLExemplar> ts;
63+
if (same.size() > 0) {// 一样的话,必须作为提示语
64+
ts = new ArrayList<>();
65+
int needSize = Math.min(noSame.size() + same.size(), fewShotNumber);
66+
if (needSize > same.size()) {
67+
ts.addAll(noSame.subList(0, needSize - same.size()));
68+
}
69+
ts.addAll(same);
70+
} else { // 至少要一个最像的
71+
ts = noSame.subList(0, Math.min(noSame.size(), fewShotNumber));
72+
if (!ts.contains(mostSimilar)) {
73+
ts.remove(ts.size() - 1);
74+
ts.add(mostSimilar);
75+
}
76+
}
77+
results.add(ts);
5978
}
60-
6179
return results;
6280
}
6381

0 commit comments

Comments
 (0)