-
Notifications
You must be signed in to change notification settings - Fork 13
/
search_prompts.py
133 lines (99 loc) · 4.45 KB
/
search_prompts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import fire
import json
from nltk import sent_tokenize
from thefuzz import fuzz
from models.gpt3 import GPT3
from data_utils.data_utils import get_n_ents, get_sent, fix_prompt_style
# Each entry is an (old, new) pair that get_paraphrase_prompt unpacks into
# str.replace(). The leading ['', ''] pair is the identity rewrite, so the
# untouched text is always tried first.

# Rewrites applied to the paraphrased sentence (dropping articles).
TRANSFORMATIONS_SENT = [
    ['', ''],
    ['a ', ''],
    ['the ', ''],
]

# Rewrites applied to an entity string before searching for it in the
# paraphrased sentence (verb-form variations).
TRANSFORMATIONS_ENT = [
    ['', ''],
    ['being', 'is'],
    ['being', 'are'],
    ['ing', ''],
    ['ing', 'e'],
]
def _mask_entity(sentence, idx, ent):
    """Try to replace the one occurrence of `ent` in `sentence` by <ENT{idx}>.

    Every combination of sentence rewrite (TRANSFORMATIONS_SENT) and entity
    rewrite (TRANSFORMATIONS_ENT) is tried in order; the first combination
    under which the rewritten entity occurs exactly once in the rewritten
    sentence wins. If none matches, `sentence` is returned unchanged.
    """
    placeholder = f'<ENT{idx}>'
    for trans_sent in TRANSFORMATIONS_SENT:
        for trans_ent in TRANSFORMATIONS_ENT:
            if sentence.count(placeholder) != 0:
                # Already masked by an earlier combination.
                return sentence
            transed_sentence = sentence.replace(*trans_sent)
            transed_ent = ent.replace(*trans_ent)
            if transed_sentence.count(transed_ent) == 1:
                sentence = transed_sentence.replace(transed_ent, placeholder)
    return sentence


def get_paraphrase_prompt(gpt3, prompt, ent_tuple):
    """Ask the model to paraphrase a filled-in prompt and re-template it.

    The prompt is turned into a concrete sentence with `get_sent`
    (presumably by filling its entity slots with `ent_tuple` -- confirm
    against data_utils), the model is asked to paraphrase that sentence, and
    the entity mentions in the paraphrase are replaced by `<ENT0>`,
    `<ENT1>`, ... placeholders. Up to 5 completions are requested.

    Args:
        gpt3: Object whose `call(prompt=...)` returns a mapping with the
            completion text at `['choices'][0]['text']`.
        prompt: Prompt template; `get_n_ents(prompt)` must equal
            `len(ent_tuple)`.
        ent_tuple: Entity strings; they are lower-cased before use.

    Returns:
        The lower-cased paraphrase with each entity replaced by its
        placeholder, or None if no attempt yielded a sentence in which every
        placeholder appears exactly once.

    Raises:
        AssertionError: If the prompt's entity count does not match
            `ent_tuple`.
    """
    assert get_n_ents(prompt) == len(ent_tuple)

    ent_tuple = [ent.lower() for ent in ent_tuple]
    sent = get_sent(prompt=prompt, ent_tuple=ent_tuple)

    for _ in range(5):
        raw_response = gpt3.call(prompt=f'paraphrase:\n{sent}\n')

        sentences = sent_tokenize(raw_response['choices'][0]['text'])
        if not sentences:
            # A blank completion tokenizes to an empty list; indexing it
            # would raise IndexError, so spend this attempt and retry.
            continue

        # Keep only the first sentence, without the trailing period.
        para_sent = sentences[0].strip().strip('.').lower()
        print('para_sent:', para_sent)

        candidate = para_sent
        for idx, ent in enumerate(ent_tuple):
            candidate = _mask_entity(candidate, idx, ent)
            if candidate.count(f'<ENT{idx}>') != 1:
                # This entity could not be located unambiguously; discard
                # the paraphrase and request another one.
                break
        else:
            return candidate

    return None
def search_prompts(init_prompts, seed_ent_tuples, similarity_threshold):
    """Grow a set of diverse prompts by repeatedly paraphrasing known ones.

    Each round paraphrases every initial prompt and every prompt accepted so
    far, once per seed entity tuple, via `get_paraphrase_prompt`. New
    paraphrases are then considered shortest-first and accepted only if
    their highest `fuzz.ratio` against the already accepted prompts is below
    `similarity_threshold`. The search ends when a round yields no new
    paraphrase, accepts none, or at least 10 prompts have been accepted.

    Args:
        init_prompts: List of starting prompt templates.
        seed_ent_tuples: Entity tuples used to fill the prompts; underscores
            in entity strings are replaced by spaces.
        similarity_threshold: Upper bound (exclusive) on `fuzz.ratio`
            similarity to any accepted prompt for a candidate to be kept.

    Returns:
        The accepted prompts, shortest first, each passed through
        `fix_prompt_style`.
    """
    gpt3 = GPT3()

    cache = {}
    prompts = []
    while True:
        new_prompts = []
        # Initial prompts appear twice and always bypass the cache below, so
        # each of them is paraphrased afresh twice per round.
        for prompt in init_prompts + init_prompts + prompts:
            for ent_tuple in seed_ent_tuples:
                ent_tuple = [ent.replace('_', ' ') for ent in ent_tuple]

                request_str = f'{prompt} ||| {ent_tuple}'
                if request_str not in cache or prompt in init_prompts:
                    cache[request_str] = get_paraphrase_prompt(
                        gpt3=gpt3, prompt=prompt, ent_tuple=ent_tuple)

                para_prompt = cache[request_str]
                print(f'prompt: {prompt}\tent_tuple: {ent_tuple}'
                      f'\t-> para_prompt: {para_prompt}')

                if para_prompt is not None and \
                        para_prompt not in init_prompts and \
                        para_prompt not in prompts:
                    new_prompts.append(para_prompt)

            # NOTE(review): assumed to be checked once per source prompt
            # (after all seed tuples) -- confirm against the original layout.
            if len(set(prompts + new_prompts)) >= 20:
                break

        if not new_prompts:
            break

        accepted_any = False
        for new_prompt in sorted(new_prompts, key=len):
            if prompts:
                # Computed once and reused for both the log line and the
                # acceptance test.
                max_sim = max(
                    fuzz.ratio(new_prompt, accepted) for accepted in prompts)
                print(f'-- {new_prompt}: {max_sim}')
                is_novel = max_sim < similarity_threshold
            else:
                is_novel = True

            if is_novel:
                prompts.append(new_prompt)
                accepted_any = True

        # Order-preserving de-duplication followed by a stable sort keeps
        # the order of equal-length prompts reproducible between runs
        # (iterating a set of strings is not).
        prompts = sorted(dict.fromkeys(prompts), key=len)

        if len(prompts) >= 10 or not accepted_any:
            break

    return [fix_prompt_style(prompt) for prompt in prompts]
def main(rel_set='conceptnet', similarity_threshold=75):
    """Search prompts for every relation in `relation_info/{rel_set}.json`.

    For each relation the initial prompts are normalized with
    `fix_prompt_style`. Relations that have no prompts yet get them from
    `search_prompts`; the result is printed and the whole relation file is
    rewritten in place so finished relations survive an interrupted run.

    Args:
        rel_set: Name of the relation set; selects the JSON file that is
            both read and overwritten.
        similarity_threshold: Forwarded to `search_prompts`.
    """
    info_path = f'relation_info/{rel_set}.json'
    with open(info_path) as info_file:
        relation_info = json.load(info_file)

    for info in relation_info.values():
        info['init_prompts'] = [
            fix_prompt_style(prompt) for prompt in info['init_prompts']]

        # Skip relations that already carry a non-empty prompt list.
        if info.get('prompts'):
            continue

        info['prompts'] = search_prompts(
            init_prompts=info['init_prompts'],
            seed_ent_tuples=info['seed_ent_tuples'],
            similarity_threshold=similarity_threshold)

        for key, value in info.items():
            print(f'{key}: {value}')
        for prompt in info['prompts']:
            print(f'- {prompt}')
        print('=' * 50)

        # NOTE(review): the report and save are assumed to run once per
        # newly searched relation -- confirm against the original layout.
        # The context manager guarantees the file is flushed and closed
        # before the next (potentially long) search starts.
        output_path = f'relation_info/{rel_set}.json'
        with open(output_path, 'w') as output_file:
            json.dump(relation_info, output_file, indent=4)
# Script entry point: hand `main` to python-fire so its keyword arguments
# (rel_set, similarity_threshold) can be supplied on the command line.
if __name__ == '__main__':
    fire.Fire(main)