In [3]:
import sys

sys.path.append('..')
import numpy as np
from dataset import sequence
from common.optimizer import Adam
from common.trainer import Trainer
from common.util import eval_seq2seq
from attention_seq2seq import AttentionSeq2seq
from seq2seq import Seq2seq
from peeky_seq2seq import PeekySeq2seq

In [4]:
# Load the dataset: sequence.load_data returns ((inputs, targets) for train,
# (inputs, targets) for test); get_vocab returns the char<->id lookup tables.
train_split, test_split = sequence.load_data('date.txt')
x_train, t_train = train_split
x_test, t_test = test_split
char_to_id, id_to_char = sequence.get_vocab()

In [5]:
# Reverse the input sequences: flip every row along axis 1 (presumably the
# time/character axis). Targets t_train / t_test keep their original order.
x_train = x_train[:, ::-1]
x_test = x_test[:, ::-1]

In [6]:
# Hyperparameters for the model and the training loop.
vocab_size = len(char_to_id)          # number of entries in the char -> id table
wordvec_size, hidden_size = 16, 256   # embedding (word-vector) size, hidden-state size
batch_size, max_epoch = 128, 10
max_grad = 5.0                        # forwarded to Trainer.fit; presumably a gradient-clipping threshold

In [7]:
# Build the attention seq2seq model, its Adam optimizer and the Trainer.
# Seed NumPy's global RNG first so that re-running from this cell gives the same
# weight initialisation and batch shuffling. This assumes the model and Trainer
# draw from np.random (TODO confirm, e.g. if a GPU/cupy backend is enabled it
# needs its own seed).
np.random.seed(1984)
model = AttentionSeq2seq(vocab_size, wordvec_size, hidden_size)
optimizer = Adam()
trainer = Trainer(model, optimizer)

In [8]:
# Alternate one epoch of training with a pass over the whole test set.
# After each epoch the test accuracy is appended to acc_list and printed.
acc_list = []
for epoch in range(max_epoch):
    # Train for exactly one epoch per call so we can evaluate in between.
    trainer.fit(x_train, t_train, max_epoch=1,
                batch_size=batch_size, max_grad=max_grad)

    n_test = len(x_test)
    correct_num = 0
    for i in range(n_test):
        # List indexing ([[i]]) keeps a leading batch axis of size 1.
        question = x_test[[i]]
        correct = t_test[[i]]
        # Only the first 10 test samples are evaluated verbosely;
        # is_reverse=True because the inputs were reversed earlier.
        correct_num += eval_seq2seq(model, question, correct, id_to_char,
                                    i < 10, is_reverse=True)

    acc = float(correct_num) / n_test
    acc_list.append(acc)
    print('val acc %.3f%%' % (acc * 100))

| epoch 1 |  iter 1 / 351 | time 0[s] | loss 4.08
| epoch 1 |  iter 21 / 351 | time 6[s] | loss 3.09
| epoch 1 |  iter 41 / 351 | time 12[s] | loss 1.90
| epoch 1 |  iter 61 / 351 | time 18[s] | loss 1.72
| epoch 1 |  iter 81 / 351 | time 24[s] | loss 1.46
| epoch 1 |  iter 101 / 351 | time 30[s] | loss 1.19
| epoch 1 |  iter 121 / 351 | time 36[s] | loss 1.14
| epoch 1 |  iter 141 / 351 | time 42[s] | loss 1.09
| epoch 1 |  iter 161 / 351 | time 48[s] | loss 1.06
| epoch 1 |  iter 181 / 351 | time 54[s] | loss 1.04
| epoch 1 |  iter 201 / 351 | time 60[s] | loss 1.03
| epoch 1 |  iter 221 / 351 | time 66[s] | loss 1.02
| epoch 1 |  iter 241 / 351 | time 72[s] | loss 1.02
| epoch 1 |  iter 261 / 351 | time 78[s] | loss 1.01
| epoch 1 |  iter 281 / 351 | time 84[s] | loss 1.00
| epoch 1 |  iter 301 / 351 | time 91[s] | loss 1.00
| epoch 1 |  iter 321 / 351 | time 97[s] | loss 1.00
| epoch 1 |  iter 341 / 351 | time 103[s] | loss 1.00
Q 10/15/94                     
T 1994-10-15
X 1978-0

KeyboardInterrupt: 

In [None]:
# Persist the model's parameters. No file name is passed, so the target path is
# whatever save_params() defaults to (not visible here, confirm before relying on it).
# NOTE(review): this saves whatever state `model` is currently in. If training was
# interrupted, the saved weights are only partially trained.
model.save_params()