-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
54 lines (42 loc) · 1.67 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
def decode_sequence(
        input_seq,
        encoder_model,
        decoder_model,
        max_words,
        max_len,
        reverse_input_char_index,
        end_token='endofsent',
        min_words=4):
    """Greedily decode one input sequence with an encoder/decoder pair.

    Assumes a batch of size 1. The encoder produces the initial recurrent
    states; the decoder is then run one step at a time, feeding each sampled
    token back in as a one-hot vector until the end token is produced or the
    length limit is exceeded.

    Parameters
    ----------
    input_seq : array-like
        Encoder input for a single sequence (whatever ``encoder_model.predict``
        accepts).
    encoder_model : model
        Model whose ``predict`` returns the initial decoder states ``[h, c]``.
    decoder_model : model
        Model whose ``predict([target_seq, h, c])`` returns
        ``(output_tokens, h, c)`` with ``output_tokens`` of shape
        ``(1, 1, max_words)``.
    max_words : int
        Vocabulary size (length of the one-hot target vector).
    max_len : int
        Length limit; decoding stops once ``n_words > max_len``.
    reverse_input_char_index : dict
        Maps a token index back to its word string.
    end_token : str, optional
        Word that terminates decoding (default ``'endofsent'``, matching the
        original hard-coded value).
    min_words : int, optional
        While ``n_words <= min_words``, a sampled index of 0 or 1 is replaced
        by the runner-up token (default 4, matching the original behavior).

    Returns
    -------
    str
        The decoded words, each followed by a single space.
    """
    # Encode the input once to obtain the initial decoder states [h, c].
    states_value = encoder_model.predict(input_seq)

    # Start-of-sequence input: an all-zero one-hot vector of length max_words.
    target_seq = np.zeros((1, 1, max_words))

    words = []
    n_words = 0
    while True:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value)

        # Greedy sampling: take the most probable token.
        sampled_token_index = int(np.argmax(output_tokens[0, 0, :]))

        # Indices 0 and 1 appear to be reserved tokens (padding/start —
        # TODO confirm against the tokenizer). While the sentence is still
        # short, fall back to the second-most-probable token instead.
        if sampled_token_index in (0, 1) and n_words <= min_words:
            sampled_token_index = int(np.argsort(output_tokens[0, 0, :])[-2])

        sampled_char = reverse_input_char_index[sampled_token_index]

        # Exit condition: length limit exceeded or end token produced.
        if n_words > max_len or sampled_char == end_token:
            break

        words.append(sampled_char)

        # Feed the sampled token back in as the next one-hot decoder input.
        target_seq = np.zeros((1, 1, max_words))
        target_seq[0, 0, sampled_token_index] = 1.0

        # Carry the updated recurrent states forward to the next step.
        states_value = [h, c]
        n_words += 1

    # Join once instead of quadratic string concatenation; each word keeps
    # its trailing space to match the original output format exactly.
    return ''.join(word + ' ' for word in words)