-
Notifications
You must be signed in to change notification settings - Fork 3
/
test.py
87 lines (69 loc) · 2.38 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os

import torch

from data_process import MyTokenizer
from data_process import Podcasts as Dataset
from layers import Retrieval
from search.beam import beam_search
from utility import load_from_pkl
def get_data(config, log, mode='test'):
    """Build the evaluation dataset for the given mode.

    Args:
        config: run configuration; ``config.test`` names the dataset split.
        log: project logger exposing ``log(msg)``.
        mode: mode flag forwarded to ``Dataset`` (default ``'test'``).

    Returns:
        Tuple ``(n_valid, valid_set)``: number of batches and the dataset.
    """
    # (removed leftover debug print to stdout)
    tokenizer = MyTokenizer(config)
    valid_set = Dataset(
        name=config.test,
        # Batch length = total token count over all sequences in the sample.
        len_func=lambda x: sum(len(it) for it in x[0]),
        config=config,
        tokenizer=tokenizer,
        log=log,
        mode=mode,
    )
    n_valid = len(valid_set)
    log.log("There are %d batches in valid data" % n_valid)
    return n_valid, valid_set
def get_network(config, log):
    """Build the Retrieval model and restore the best checkpoint.

    Strips the ``nn.DataParallel`` ``"module."`` key prefix when present,
    moves the model to ``config.device`` and switches it to eval mode.

    Args:
        config: run configuration with ``save_path``, ``model``, ``device``.
        log: project logger exposing ``log(msg)``.

    Returns:
        The loaded ``Retrieval`` model, on GPU and in eval mode.
    """
    # Model Setup
    log.log("Building Model")
    net = Retrieval(config)

    # Loading Parameter
    log.log("Loading Parameters")
    # Load onto CPU first so the checkpoint restores even if it was saved
    # from a GPU that is not available here; the model is moved to the
    # configured device below.
    checkpoint = torch.load(
        os.path.join(config.save_path, config.model), map_location="cpu"
    )
    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    new_stat_dict = {
        (key[len("module."):] if key.startswith("module.") else key): value
        for key, value in checkpoint["state_dict"].items()
    }
    net.load_state_dict(new_stat_dict)
    log.log("Parameters Loaded")

    net = net.cuda(config.device)
    net.eval()
    log.log("Finished Build Model")
    return net
def test(config, log):
    """Run beam-search decoding over the test set and write summaries.

    Loads precomputed hidden vectors from a pickle, decodes each batch with
    beam search, and writes one " [SSPLIT] "-joined summary per line to
    ``summary<suffix>.txt``.

    Args:
        config: run configuration (mutated: ``batch_size`` forced to 1,
            ``mode`` forced to "ret").
        log: project logger exposing ``log(msg)``.
    """
    tokenizer = MyTokenizer(config)
    config.batch_size = 1
    config.mode = "ret"
    net = get_network(config, log)
    _, valid_set = get_data(config, log)

    name = "hidden_v_" + str(config.kernel_size) + "_" + str(config.stride) + "_" + config.window_type + ".pkl"
    h_v = load_from_pkl(name)
    suffix = "_" + str(config.stride) + "_" + str(config.beam_size) + "_" + str(config.length_penalty)
    torch.cuda.empty_cache()

    separator = " [SSPLIT] "
    # Context manager guarantees the output file is closed even if a batch
    # raises during decoding (the original leaked the handle on error).
    with open("summary" + suffix + ".txt", "w") as f:
        for batch_idx, batch_data in enumerate(valid_set):
            print(batch_idx)
            answer, _ = beam_search(net, batch_data, config, h=h_v, output_others=True)
            # Re-segment the decoded sentences: after the first one, move each
            # sentence's leading token onto the end of the previous segment.
            ans = []
            first = True
            for sent in answer[2]:
                if first:
                    ans.append(sent)
                    first = False
                else:
                    ans[-1].append(sent[0])
                    ans.append(sent[1:])
            summary_text = separator.join(tokenizer.decode(sent) for sent in ans)
            # Defensive: drop a trailing separator if decoding produced one.
            if summary_text.endswith(separator):
                summary_text = summary_text[:-len(separator)]
            print(summary_text)
            print(summary_text, file=f)