In [1]:
from seq2seq_birnn import Seq2Seq
import sys
if int(sys.version[0]) == 2:
    from io import open


def read_data(path):
    """Return the entire contents of the UTF-8 text file at *path* as one string."""
    with open(path, mode='r', encoding='utf-8') as handle:
        text = handle.read()
    return text
# end function read_data


def build_map(data):
    """Build index<->character vocabularies for the characters in *data*.

    The four special tokens '<GO>', '<EOS>', '<PAD>', '<UNK>' take indices
    0-3, in that order. Every distinct character of *data* except the line
    separator '\\n' follows, in sorted order.

    Returns:
        (idx2char, char2idx): two dicts that are inverses of each other.
    """
    specials = ['<GO>', '<EOS>', '<PAD>', '<UNK>']
    # Sort instead of using raw set order: str hashing is randomized per
    # process in Python 3, so set iteration order (and therefore every index
    # assignment) would otherwise change from run to run.
    chars = sorted(set(data) - {'\n'})
    idx2char = dict(enumerate(specials + chars))
    char2idx = {char: idx for idx, char in idx2char.items()}
    return idx2char, char2idx
# end function build_map


def _split_lines(text):
    """Split *text* on '\\n', treating one final newline as a terminator, not an extra empty line."""
    if text.endswith('\n'):
        text = text[:-1]
    return text.split('\n')


def preprocess_data(source_path='temp/letters_source.txt',
                    target_path='temp/letters_target.txt'):
    """Load parallel source/target text files and encode each line as indices.

    Line i of *source_path* is paired with line i of *target_path*. Each
    source line becomes a list of source-vocabulary indices. Each target line
    becomes a list of target-vocabulary indices followed by the '<EOS>' index.

    Args:
        source_path: UTF-8 file with one source sequence per line.
        target_path: UTF-8 file with one target sequence per line.

    Returns:
        (X_indices, Y_indices, X_char2idx, Y_char2idx, X_idx2char, Y_idx2char)

    Raises:
        ValueError: if the two files contain a different number of lines.
    """
    X_data = read_data(source_path)
    Y_data = read_data(target_path)

    X_idx2char, X_char2idx = build_map(X_data)
    Y_idx2char, Y_char2idx = build_map(Y_data)

    x_unk = X_char2idx['<UNK>']
    y_unk = Y_char2idx['<UNK>']
    y_eos = Y_char2idx['<EOS>']

    # A trailing newline would otherwise yield a spurious empty sample
    # (empty source, EOS-only target).
    X_lines = _split_lines(X_data)
    Y_lines = _split_lines(Y_data)
    if len(X_lines) != len(Y_lines):
        # Mismatched counts would silently pair the wrong source/target lines.
        raise ValueError(
            'source and target line counts differ: %d vs %d'
            % (len(X_lines), len(Y_lines)))

    X_indices = [[X_char2idx.get(char, x_unk) for char in line] for line in X_lines]
    Y_indices = [[Y_char2idx.get(char, y_unk) for char in line] + [y_eos] for line in Y_lines]

    return X_indices, Y_indices, X_char2idx, Y_char2idx, X_idx2char, Y_idx2char
# end function preprocess_data


def main():
    """Train the bidirectional-RNN Seq2Seq model and run a few sample inferences.

    The first batch of encoded samples is held out as validation data. All
    remaining samples are used for training.
    """
    batch = 128
    (src_seqs, tgt_seqs,
     src_char2idx, tgt_char2idx,
     src_idx2char, tgt_idx2char) = preprocess_data()

    # Hold out the first batch for validation; train on everything after it.
    src_val, src_fit = src_seqs[:batch], src_seqs[batch:]
    tgt_val, tgt_fit = tgt_seqs[:batch], tgt_seqs[batch:]

    hyperparams = dict(
        rnn_size=50,
        n_layers=2,
        X_word2idx=src_char2idx,
        encoder_embedding_dim=15,
        Y_word2idx=tgt_char2idx,
        decoder_embedding_dim=15,
    )
    model = Seq2Seq(**hyperparams)
    model.fit(src_fit, tgt_fit, val_data=(src_val, tgt_val), batch_size=batch)

    for sample in ('common', 'apple', 'zhedong'):
        model.infer(sample, src_idx2char, tgt_idx2char)
# end function main


# Script entry point: run training and the sample inferences only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()

Epoch 1/60 | Batch 0/77 | train_loss: 3.399 | test_loss: 3.396
Epoch 1/60 | Batch 50/77 | train_loss: 2.818 | test_loss: 2.804
Epoch 2/60 | Batch 0/77 | train_loss: 2.346 | test_loss: 2.373
Epoch 2/60 | Batch 50/77 | train_loss: 1.831 | test_loss: 1.813
Epoch 3/60 | Batch 0/77 | train_loss: 1.592 | test_loss: 1.620
Epoch 3/60 | Batch 50/77 | train_loss: 1.415 | test_loss: 1.367
Epoch 4/60 | Batch 0/77 | train_loss: 1.237 | test_loss: 1.265
Epoch 4/60 | Batch 50/77 | train_loss: 1.139 | test_loss: 1.102
Epoch 5/60 | Batch 0/77 | train_loss: 0.992 | test_loss: 1.033
Epoch 5/60 | Batch 50/77 | train_loss: 0.934 | test_loss: 0.915
Epoch 6/60 | Batch 0/77 | train_loss: 0.807 | test_loss: 0.862
Epoch 6/60 | Batch 50/77 | train_loss: 0.775 | test_loss: 0.766
Epoch 7/60 | Batch 0/77 | train_loss: 0.662 | test_loss: 0.717
Epoch 7/60 | Batch 50/77 | train_loss: 0.646 | test_loss: 0.640
Epoch 8/60 | Batch 0/77 | train_loss: 0.550 | test_loss: 0.601
Epoch 8/60 | Batch 50/77 | train_loss: 0.539 | t