# generate_instances.py
"""Generate instances (example inputs/outputs) for machine-generated
instructions by prompting a GPT-3 engine, one batch of instructions at a time."""
import os
import json
import random
import tqdm
import re
import argparse
import pandas as pd
from collections import OrderedDict
from gpt3_api import make_requests as make_gpt3_requests
from templates.instance_gen_template import output_first_template_for_clf, input_first_template_for_gen

# Fix the global RNG so any sampling is reproducible across runs.
random.seed(42)
def parse_args():
    """Build and parse the command-line arguments for instance generation.

    Returns:
        argparse.Namespace with batch/output paths, task-type filters, and
        GPT-3 request settings.
    """
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument

    # Input / output locations.
    add("--batch_dir", type=str, required=True,
        help="The directory where the batch is stored.")
    add("--input_file", type=str,
        default="machine_generated_instructions.jsonl")
    add("--output_file", type=str,
        default="machine_generated_instances.jsonl")

    # How much to generate.
    add("--num_instructions", type=int,
        help="if specified, only generate instance input for this many instructions")
    add("--max_instances_to_generate", type=int, default=5,
        help="The max number of instances to generate for each instruction.")

    # Task-type filters (mutually independent flags).
    add("--generation_tasks_only", action="store_true",
        help="If specified, only do for generation tasks.")
    add("--classification_tasks_only", action="store_true",
        help="If specified, only do for classification tasks.")

    # GPT-3 request settings.
    add("--engine", type=str, default="davinci",
        help="The engine to use.")
    add("--request_batch_size", type=int, default=5,
        help="The number of requests to send in a batch.")
    add("--api_key", type=str,
        help="The API key to use. If not specified, the key will be read from the environment variable OPENAI_API_KEY.")
    add("--organization", type=str,
        help="The organization to use. If not specified, the default organization id will be used.")

    return arg_parser.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # Load the machine-generated instructions, optionally truncating to the
    # first --num_instructions entries.
    with open(os.path.join(args.batch_dir, args.input_file)) as fin:
        lines = fin.readlines()
        if args.num_instructions is not None:
            lines = lines[:args.num_instructions]

    tasks = []
    for line in lines:
        data = json.loads(line)
        # Rename "metadata" -> "instruction_metadata" so it does not clash
        # with the instance-level metadata attached below.
        if "metadata" in data:
            data["instruction_metadata"] = data["metadata"]
            del data["metadata"]
        tasks.append(data)

    # Map each instruction to whether it was labeled a classification task by
    # the earlier is-classification step.
    task_clf_types = {}
    with open(os.path.join(args.batch_dir, "is_clf_or_not_davinci_template_1.jsonl")) as fin:
        for line in fin:
            data = json.loads(line)
            task_clf_types[data["instruction"]] = data["is_classification"].strip() in ["Yes", "yes", "YES"]

    if args.classification_tasks_only:
        tasks = [task for task in tasks if task_clf_types[task["instruction"]]]
    if args.generation_tasks_only:
        tasks = [task for task in tasks if not task_clf_types[task["instruction"]]]

    output_path = os.path.join(args.batch_dir, args.output_file)

    # Reuse previously generated results so an interrupted run can be resumed
    # without re-querying the API for completed instructions.
    existing_requests = {}
    if os.path.exists(output_path):
        with open(output_path) as fin:
            for line in tqdm.tqdm(fin):
                try:
                    data = json.loads(line)
                    existing_requests[data["instruction"]] = data
                except json.JSONDecodeError:
                    # Only skip lines that are not valid JSON (e.g. a partial
                    # write from an interrupted run); a bare `except` here
                    # would also hide real bugs such as missing keys.
                    pass
        print(f"Loaded {len(existing_requests)} existing requests")

    progress_bar = tqdm.tqdm(total=len(tasks))
    with open(output_path, "w") as fout:
        for batch_idx in range(0, len(tasks), args.request_batch_size):
            batch = tasks[batch_idx: batch_idx + args.request_batch_size]
            if all(d["instruction"] in existing_requests for d in batch):
                # Every instruction in this batch is cached: re-serialize the
                # cached results in a canonical key order.
                for d in batch:
                    data = existing_requests[d["instruction"]]
                    data = OrderedDict(
                        (k, data[k]) for k in
                        ["instruction", "raw_instances", "instance_metadata", "instruction_metadata",
                         "most_similar", "avg_similarity_score"]
                    )
                    fout.write(json.dumps(data, ensure_ascii=False) + "\n")
            else:
                # Classification tasks use the output-first template,
                # generation tasks the input-first one.
                prompts = []
                for task in batch:
                    if task_clf_types[task["instruction"]]:
                        template = output_first_template_for_clf
                    else:
                        template = input_first_template_for_gen
                    prompts.append(template + " " + task["instruction"].strip() + "\n")
                results = make_gpt3_requests(
                    engine=args.engine,
                    prompts=prompts,
                    # because the clf template is longer, we need to decrease the max_tokens
                    max_tokens=300 if any(task_clf_types[task["instruction"]] for task in batch) else 350,
                    temperature=0,
                    top_p=0,
                    frequency_penalty=0,
                    presence_penalty=1.5,
                    # Stop once the model starts one example past the limit,
                    # or starts a brand-new task.
                    stop_sequences=[f"Example {args.max_instances_to_generate + 1}", "Task:"],
                    logprobs=1,
                    n=1,
                    best_of=1,
                    api_key=args.api_key,
                    organization=args.organization)
                for i in range(len(batch)):
                    data = batch[i]
                    data["instance_metadata"] = results[i]
                    if results[i]["response"] is not None:
                        data["raw_instances"] = results[i]["response"]["choices"][0]["text"]
                    else:
                        # Request failed: keep the task but with no instances.
                        data["raw_instances"] = ""
                    data = OrderedDict(
                        (k, data[k]) for k in
                        ["instruction", "raw_instances", "instance_metadata", "instruction_metadata",
                         "most_similar", "avg_similarity_score"]
                    )
                    fout.write(json.dumps(data, ensure_ascii=False) + "\n")
            progress_bar.update(len(batch))