-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_seq2seq.py
88 lines (65 loc) · 2.74 KB
/
run_seq2seq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import sys
import logging
import argparse
import data_util
from Seq2SeqModelTF import Seq2SeqModelTF
from config.ConfigHandler import ConfigHandler
__author__ = "roopal_garg"
# Route all log records to stdout with a timestamped format so training
# progress is visible when the script is run interactively or piped.
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format='%(asctime)s : %(message)s',
datefmt='%m/%d/%Y %I:%M:%S %p'
)
# Bucket sizes (encoder/decoder length pairs) shared with data_util.
BUCKETS = data_util.BUCKETS
# Phrase that ends the interactive chat loop in test mode; read from the
# [model_param] section of the project config.
EXIT_PHRASE = ConfigHandler.get("exit_term", "model_param")
def main():
    """Train or chat with a bucketed seq2seq model.

    Reads ``--mode`` ("train" or "test") from the command line and every
    hyper-parameter from the [model_param] section of the project config.
    In train mode it fits the model on the prepared datasets; in test mode
    it restores the latest checkpoint and runs an interactive REPL that
    exits on EOF or when the user types ``EXIT_PHRASE``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", help="train or test mode", default="train", type=str)
    args = parser.parse_args()

    logging.info("preparing data")
    enc_train, dec_train, enc_dev, dec_dev, _, _ = data_util.prepare_datasets()

    logging.info("initializing the model")
    # All hyper-parameters come from the [model_param] config section.
    lr = ConfigHandler.getfloat("learning_rate", "model_param")
    vocab_size_enc = ConfigHandler.getint("vocab_size_enc", "model_param")
    vocab_size_dec = ConfigHandler.getint("vocab_size_dec", "model_param")
    num_layers = ConfigHandler.getint("num_layers", "model_param")
    mx_grad_nrm = ConfigHandler.getfloat("mx_grad_nrm", "model_param")
    batch_size = ConfigHandler.getint("batch_size", "model_param")
    m = ConfigHandler.getint("layer_size", "model_param")
    num_samples = ConfigHandler.getint("num_samples", "model_param")
    use_lstm = ConfigHandler.get_boolean("use_lstm", "model_param")

    mode = args.mode
    # Build an inference-only graph (no backward pass) when testing.
    fwd_only = mode == "test"

    model = Seq2SeqModelTF(
        src_vocab_size=vocab_size_enc, tgt_vocab_size=vocab_size_dec, buckets=BUCKETS, m=m, num_layers=num_layers,
        mx_grad_nrm=mx_grad_nrm, batch_size=batch_size, lr=lr, model_name="seq2seq", save_dir="train_log",
        use_lstm=use_lstm, num_samples=num_samples, fwd_only=fwd_only
    )

    if mode == "train":
        logging.info("mode: training")
        test_every = ConfigHandler.getint("test_every", "model_param")
        max_train_data_size = ConfigHandler.getint("max_train_data_size", "model_param")
        model.fit(
            enc_train, dec_train, enc_dev, dec_dev, max_train_data_size=max_train_data_size, test_every=test_every
        )
    elif mode == "test":
        logging.info("mode: testing")
        model.restore_latest_model()
        logging.info("beginning conversation, your turn first ({} to exit)".format(EXIT_PHRASE))
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            # BUG FIX: readline() keeps the trailing newline, so a raw
            # equality check against EXIT_PHRASE could never match and the
            # loop only ended on EOF. Strip whitespace before comparing.
            if sentence.strip() == EXIT_PHRASE:
                logging.info("ending conversation, have a good day!")
                break
            model.test(sentence)
            sys.stdout.write("> ")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
# Script entry point: only run main() when executed directly, not on import.
if __name__ == "__main__":
    main()