-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate_synthetic_queries.py
76 lines (56 loc) · 2.79 KB
/
generate_synthetic_queries.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import re
from evaluation import get_relevance_label_df
from evaluation import jaccard_similarity
from evaluation import levenstein_distance
import sys
sys.path.append('../question_generator/')
from questiongenerator import QuestionGenerator
def remove_urls(text):
    """Strip URL-like tokens from *text* and return the result.

    Two substitution passes run in order: first every run of non-whitespace
    characters starting with ``http`` is deleted, then every such run
    starting with ``www``. Whitespace around the removed tokens is kept.
    """
    # Order matters: the ``www`` pass sees the output of the ``http`` pass.
    for url_pattern in (r'http\S+', r'www\S+'):
        text = re.sub(url_pattern, '', text)
    return text
class SyntheticQueryGenerator:
    """Generate synthetic questions for the answers in a relevance-label dataset.

    A ``QuestionGenerator`` proposes questions for each answer text. Questions
    whose ``'confidence'`` is at least ``min_conf_score`` are paired with the
    answer and compared against the row's original question.
    """

    # Query types accepted by ``__init__`` and used to filter dataset rows.
    QUERY_TYPES = frozenset({'faq', 'user_query'})

    def __init__(self, model_dir=None, query_type='faq', min_conf_score=0.60,
                 answer_style='sentences', num_questions=None):
        """Validate the configuration and construct the ``QuestionGenerator``.

        Args:
            model_dir: Forwarded unchanged to ``QuestionGenerator``.
            query_type: Which dataset rows to expand; must be ``'faq'`` or
                ``'user_query'``.
            min_conf_score: Minimum ``'confidence'`` a generated question must
                have to be kept.
            answer_style: Forwarded to ``QuestionGenerator.generate``.
            num_questions: Forwarded to ``QuestionGenerator.generate``.

        Raises:
            ValueError: If ``query_type`` is not one of ``QUERY_TYPES``.
        """
        # Validate before constructing the generator so a bad argument fails
        # fast, without touching the (presumably expensive) model setup.
        if query_type not in self.QUERY_TYPES:
            raise ValueError('error, query_type not exists')
        self.query_type = query_type
        self.min_conf_score = min_conf_score
        self.answer_style = answer_style
        self.num_questions = num_questions
        self.qg = QuestionGenerator(model_dir=model_dir)

    def generate_synthetic_qas(self, text):
        """Return whatever ``QuestionGenerator.generate`` produces for *text*.

        Callers in this class expect an iterable of dicts that carry at least
        ``'question'`` and ``'confidence'`` keys.
        """
        return self.qg.generate(
            text, answer_style=self.answer_style, num_questions=self.num_questions
        )

    def generate_synthetic_query_answer_pairs(self, query_answer_pair_filepath):
        """Build synthetic query/answer records from a relevance-label file.

        The method selects rows whose ``query_type`` column equals
        ``self.query_type``. For each selected row it:

        * strips URLs from the answer;
        * generates synthetic questions for it;
        * drops low-confidence and duplicate questions;
        * emits one record per remaining question, in first-seen order.

        Args:
            query_answer_pair_filepath: Path passed to
                ``get_relevance_label_df``. The resulting frame must provide
                ``answer``, ``question``, ``label``, ``id`` and ``query_type``
                columns.

        Returns:
            A list of dicts with these keys:

            * ``label``
            * ``query_type`` (always ``"synthetic"``)
            * ``org_question``
            * ``question``
            * ``answer`` (URL-stripped)
            * ``jc_sim`` and ``lv_dist`` (strings with four decimals)
            * ``id``
        """
        relevance_label_df = get_relevance_label_df(query_answer_pair_filepath)
        # __init__ guarantees a valid query_type. The membership check only
        # matters if the attribute is reassigned later; in that case no
        # filtering is applied.
        if self.query_type in self.QUERY_TYPES:
            relevance_label_df = relevance_label_df[
                relevance_label_df['query_type'] == self.query_type
            ]

        synthetic_query_answer_pairs = []
        for _, row in relevance_label_df.iterrows():
            question = row['question']
            label = row['label']
            row_id = row['id']
            # NOTE(review): assumes every answer is a str; a missing value
            # (e.g. NaN from the dataframe) would make remove_urls raise.
            answer = remove_urls(row['answer'])
            t5_qas = self.generate_synthetic_qas(answer)
            # De-duplicate while keeping first-seen order. A set would make
            # the output order depend on string-hash randomization, so two
            # runs over the same data could emit records in different orders.
            t5_questions = list(dict.fromkeys(
                item['question'] for item in t5_qas
                if item['confidence'] >= self.min_conf_score
            ))
            for t5_question in t5_questions:
                jc_sim = jaccard_similarity(question, t5_question)
                lv_dist = levenstein_distance(question, t5_question)
                synthetic_query_answer_pairs.append({
                    'label': label,
                    'query_type': "synthetic",
                    'org_question': question,
                    'question': t5_question,
                    'answer': answer,
                    'jc_sim': "{0:.4f}".format(jc_sim),
                    'lv_dist': "{0:.4f}".format(lv_dist),
                    'id': row_id,
                })
        return synthetic_query_answer_pairs