-
Notifications
You must be signed in to change notification settings - Fork 5
/
distill.py
123 lines (105 loc) · 2.96 KB
/
distill.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import requests
import hashlib
import sys
import sqlite3
# A local database that caches the result of all queries, so an
# interrupted run can be resumed without re-querying the API.
#
# Schema: results(key text, turns text, results text)
# - key: hash of the conversation
# - turns: all human queries in the conversation
# - results: llama-70b-chat result, together with human prompts
#   in a format that can be used for fine-tuning
#
con = sqlite3.connect("sharegpt.db")
cur = con.cursor()
# Create the table on first use (no-op when it already exists) so the
# script also works against a fresh database file.
cur.execute("CREATE TABLE IF NOT EXISTS results(key text, turns text, results text);")
con.commit()
# A file that contains all human prompts, in a format of
# [["..", ".."], # first conversation
#  ["..", ".."]  # second conversation
# ]
#
with open(sys.argv[1]) as in_file:
    conversations = json.load(in_file)
# Output file ("<input path>output.json"), one JSON object per line.
# Deliberately kept open for the whole run; closed at end of script.
out_file = open(sys.argv[1] + "output.json", "w")
# Together API end point and credentials
endpoint = 'https://api.together.xyz/inference'
APIKEY = "PUT YOUR TOGETHER API KEY HERE"
# For each conversation: skip ones already cached in the database,
# otherwise replay every human turn against the Together API,
# accumulating the full transcript in `prompt`.
for turns in conversations:
    # Create a signature (cache key) for the conversation
    #
    m = hashlib.sha256()
    m.update(" ".join(turns).encode())
    key = m.hexdigest()
    # Check if this conversation is already in the database
    # - If so, skip
    # - Otherwise, query
    #
    # Use a parameterized query, and the `key` column documented in the
    # schema above (the original queried a non-existent `id` column).
    res = cur.execute("SELECT results FROM results WHERE key = ?;", (key,))
    if res.fetchone() is not None:
        continue
    failed = False  # did any API call in this conversation fail?
    prompt = ""     # running transcript in llama-2 chat format
    # For each human query in the conversation
    #
    for turn in turns:
        # Create the prompt by appending the human query
        #
        prompt = prompt + f" [INST] {turn} [/INST] "
        # Query Together API with the transcript so far
        #
        res = requests.post(endpoint, json={
            "model": "togethercomputer/llama-2-70b-chat",
            "max_tokens": 1024,
            "prompt": prompt,
            "request_type": "language-model-inference",
            "temperature": 0.7,
            "top_p": 0.7,
            "top_k": 50,
            "repetition_penalty": 1,
            "stop": [
                "[INST]"
            ],
            "safety_model": "",
            "repetitive_penalty": 1
        }, headers={
            "Authorization": "Bearer " + APIKEY,
        })
        # Parse out the response; a malformed or error reply (non-JSON
        # body, missing keys) aborts this conversation instead of being
        # silently swallowed by a bare `except`.
        try:
            response = res.json()["output"]["choices"][0]["text"]
        except (ValueError, KeyError, IndexError):
            failed = True
            print(repr(res))
            break
        # Append the response so the next turn sees the full context
        #
        prompt = prompt + response
    # If all queries succeeded, insert the result into the DB
    # (also print it out)
    #
    if not failed:
        cur.execute("INSERT INTO results VALUES (?, ?, ?);", (key, repr(turns), prompt))
        con.commit()
        print(prompt)
        print("----------------------------")
    else:
        print("## Failed!!", prompt)
        print("----------------------------")
# Dump every cached result to the output file, one JSON object
# ({"text": ...}) per line, after some basic cleaning.
res = cur.execute("SELECT results FROM results;")
for (text,) in res.fetchall():
    # Ensure a space before every [INST] marker.
    text = text.replace("[INST]", " [INST]")
    # Strip "i / i" progress markers that some prompts carry,
    # e.g. "[INST] 3 / 3- ..." -> "[INST] ...".
    for i in range(20):
        if f"[INST] {i} / {i}-" in text:
            text = text.replace(f"INST] {i} / {i}-", "INST] ")
        if f"[INST] {i} / {i}" in text:
            text = text.replace(f"INST] {i} / {i}", "INST] ")
    text = text.strip()
    out_file.write(json.dumps({"text": text}) + "\n")
out_file.close()
# Release the SQLite connection (fix: it was never closed).
con.close()