-
Notifications
You must be signed in to change notification settings - Fork 28
/
lm.py
64 lines (49 loc) · 1.97 KB
/
lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pickle
import os
import time
class LM(object):
    """Base class for language models with a pickle-backed on-disk response cache.

    Subclasses must implement ``load_model`` (to populate ``self.model``) and
    ``_generate`` (the actual decoding call). ``generate`` memoizes outputs in
    ``self.cache_dict`` keyed by ``(prompt, sample_idx)`` and ``save_cache``
    persists them, merging with entries written by parallel processes.
    """

    def __init__(self, cache_file):
        # Path of the pickle file used to persist generated outputs.
        self.cache_file = cache_file
        self.cache_dict = self.load_cache()
        self.model = None  # lazily populated by load_model() on first generate()
        self.add_n = 0     # number of new cache entries since the last save

    def load_model(self):
        """Load the underlying model and assign it to ``self.model``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def generate(self, prompt, sample_idx=0, max_sequence_length=2048, max_output_length=128):
        """Return the model output for ``prompt``, consulting the cache first.

        Args:
            prompt: input text; trailing whitespace is stripped.
            sample_idx: distinguishes multiple samples for the same prompt
                in the cache key.
            max_sequence_length: maximum input length passed to ``_generate``.
            max_output_length: maximum output length passed to ``_generate``
                (overridden to 1 for True/False prompts).

        Returns:
            Whatever ``_generate`` returns (cached on subsequent calls).
        """
        prompt = prompt.strip()  # it's important not to end with a whitespace
        cache_key = f"{prompt}_{sample_idx}"
        if cache_key in self.cache_dict:
            return self.cache_dict[cache_key]
        if self.model is None:
            self.load_model()
        # True/False questions only ever need a single output token.
        if prompt.endswith(" True or False?\nAnswer:"):
            generated = self._generate(prompt, max_sequence_length=max_sequence_length, max_output_length=1)
        else:
            generated = self._generate(prompt, max_sequence_length=max_sequence_length, max_output_length=max_output_length)
        self.cache_dict[cache_key] = generated
        self.add_n += 1
        return generated

    def save_cache(self):
        """Write the cache to disk, merging in entries from other processes."""
        if self.add_n == 0:
            return  # nothing new to persist
        # Reload the latest on-disk cache first: if other processes ran in
        # parallel, the file may contain entries we do not have in memory.
        self.cache_dict.update(self.load_cache())
        with open(self.cache_file, "wb") as f:
            pickle.dump(self.cache_dict, f)

    def load_cache(self, allow_retry=True):
        """Load and return the cache dict from ``self.cache_file``.

        Returns an empty dict when the file does not exist. On a read error
        (e.g. a concurrent writer caught mid-dump) it retries every 5 seconds,
        unless ``allow_retry`` is False, in which case the error is re-raised.

        NOTE(review): ``pickle.load`` is only safe on trusted cache files —
        never point ``cache_file`` at untrusted data.
        """
        if not os.path.exists(self.cache_file):
            return {}
        while True:
            try:
                with open(self.cache_file, "rb") as f:
                    return pickle.load(f)
            except Exception:
                # Fix: the original used `assert False` here, which is
                # stripped under `python -O` (leaving `cache` unbound) and
                # otherwise raises an uninformative AssertionError.
                # Re-raise the underlying error instead.
                if not allow_retry:
                    raise
                print("Pickle Error: Retry in 5sec...")
                time.sleep(5)