This repository has been archived by the owner on Apr 28, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 11
/
completion.py
74 lines (61 loc) · 2.21 KB
/
completion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import re
from model import get_model
from abstractions import Choice
from constants import (
STARCODER_TOKENS,
LLAMA_TOKENS,
INFILL,
DEVICE,
)
# Load the model/tokenizer pair once at import time; shared by every helper below.
model, tokenizer = get_model()
def get_inputs(prompt):
    """Tokenize *prompt* into padded PyTorch tensors moved to DEVICE."""
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        return_token_type_ids=False,
    )
    return encoded.to(DEVICE)
def get_prefix_suffix(prompt):
    """Split *prompt* at the INFILL marker into a [prefix, suffix] pair.

    Padding with two empty strings guarantees two elements even when
    the marker is absent (the suffix then defaults to "").
    """
    pieces = prompt.split(INFILL)
    pieces.extend(["", ""])
    return pieces[:2]
def get_outputs(payload, prompt):
    """Run sampled generation for *prompt* with the payload's decoding knobs.

    Gradients are disabled since this is inference-only; padding uses the
    tokenizer's EOS token id.
    """
    sampling_kwargs = dict(
        top_k=payload.top_k,
        top_p=payload.top_p,
        num_return_sequences=payload.num_return_sequences,
        do_sample=True,
        temperature=payload.temperature,
        max_new_tokens=payload.max_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        generated = model.generate(**get_inputs(prompt), **sampling_kwargs)
    return generated
def get_starcoder_completion(payload):
    """Generate a fill-in-the-middle completion with a StarCoder-style model.

    Builds a <PRE>prefix<SUF>suffix<MID> prompt, decodes the first returned
    sequence, and extracts the text between the MID token and the EOD token
    (or the end of the decoded output when EOD is absent).

    Returns a single-element list of Choice.
    """
    prefix, suffix = get_prefix_suffix(payload.prompt)
    prompt = f"{STARCODER_TOKENS['PRE']}{prefix}{STARCODER_TOKENS['SUF']}{suffix}{STARCODER_TOKENS['MID']}"
    outputs = get_outputs(payload, prompt)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
    start = decoded.find(STARCODER_TOKENS["MID"]) + len(STARCODER_TOKENS["MID"])
    # BUGFIX: str.find returns -1 (truthy) when the token is missing, so the
    # original `find(...) or len(decoded)` never fell back and sliced off the
    # last character. Check for -1 explicitly instead.
    end = decoded.find(STARCODER_TOKENS["EOD"], start)
    if end == -1:
        end = len(decoded)
    completion = decoded[start:end]
    # BUGFIX: the original left `text` unassigned when the completion was
    # non-empty and one_line was False, raising UnboundLocalError at return.
    text = completion
    if payload.one_line:
        # Keep the first non-empty line (intent of the original
        # `splitlines()[0] or splitlines()[1]`), defaulting to "".
        text = next((line for line in completion.splitlines() if line), "")
    return [Choice(text=text)]
def get_llama_completion(payload):
    """Generate fill-in-the-middle completions with a CodeLlama-style model.

    Emits one Choice per returned sequence; a Choice with empty text is
    produced when the decoded output lacks the expected "<MID>" marker.
    """
    prefix, suffix = get_prefix_suffix(payload.prompt)
    prompt = f"{LLAMA_TOKENS['PRE']} {prefix} {LLAMA_TOKENS['SUF']}{suffix} {LLAMA_TOKENS['MID']}"
    choices = []
    for sequence in get_outputs(payload, prompt):
        decoded = tokenizer.decode(sequence, skip_special_tokens=False)
        # NOTE(review): the literal "<MID>" (and "<EOT></s>" below) is
        # assumed to match LLAMA_TOKENS['MID'] — confirm against constants.py.
        # `.` does not cross newlines, so only text up to the first newline
        # after the marker is captured.
        found = re.search(r"<MID>(.*)", decoded)
        if not found:
            choices.append(Choice(text=""))
            continue
        body = found.group(1).replace("<EOT></s>", "")
        choices.append(Choice(text=body.rstrip()))
    return choices