In [1]:
from pointer_net import PointerNetwork
import sys
import numpy as np
if int(sys.version[0]) == 2:
    from io import open


def read_data(path):
    """Return the entire contents of the UTF-8 text file at ``path``.

    Args:
        path: Location of the file to read.

    Returns:
        The decoded file contents as a single string, newlines included.
    """
    with open(path, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents
# end function


def build_map(data):
    """Build index<->character lookup tables for the characters in ``data``.

    The four special tokens always occupy indices 0-3 (in the order
    <GO>, <EOS>, <PAD>, <UNK>); every distinct character found on the
    newline-separated lines of ``data`` follows in sorted order.

    Args:
        data: Text whose lines are separated by '\\n'. The newline itself is
            never added to the vocabulary.

    Returns:
        A pair ``(idx2char, char2idx)`` of mutually inverse dictionaries.
    """
    specials = ['<GO>', '<EOS>', '<PAD>', '<UNK>']

    # Gather the distinct characters line by line so '\n' is excluded.
    unique_chars = set()
    for line in data.split('\n'):
        unique_chars.update(line)

    vocab = specials + sorted(unique_chars)
    idx2char = dict(enumerate(vocab))
    char2idx = dict(zip(vocab, range(len(vocab))))
    return idx2char, char2idx
# end function


def preprocess_data(max_len, source_path='temp/letters_source.txt',
                    target_path='temp/letters_target.txt'):
    """Read the source/target line files and encode them as pointer targets.

    Every source line is mapped to character indices, terminated with <EOS>
    and right-padded with <PAD> to exactly ``max_len`` entries. Every target
    line is encoded the same way (using the *source* vocabulary) and then
    rewritten as positions: each target token is replaced by the index of its
    first occurrence in the padded source row.

    Source and target lines are paired with ``zip``, so surplus lines in the
    longer file are ignored.

    Args:
        max_len: Fixed row length, counting the <EOS> token; a line may
            therefore hold at most ``max_len - 1`` characters.
        source_path: File holding one source sequence per line.
        target_path: File holding one target sequence per line.

    Returns:
        A tuple ``(X_indices, X_seq_len, Y_indices, Y_seq_len, X_char2idx,
        X_idx2char)`` where ``X_indices`` holds the padded source index rows,
        ``Y_indices`` the pointer positions into those rows, the ``*_seq_len``
        arrays the unpadded lengths including <EOS>, and the last two items
        are the vocabulary dictionaries returned by ``build_map``.

    Raises:
        ValueError: If a line has more than ``max_len - 1`` characters (it
            could not be padded to ``max_len``), or if a target token does
            not occur anywhere in its padded source row.
    """
    X_data = read_data(source_path)
    Y_data = read_data(target_path)

    X_idx2char, X_char2idx = build_map(X_data)
    print("==> Word Index Built")

    x_unk = X_char2idx['<UNK>']
    x_eos = X_char2idx['<EOS>']
    x_pad = X_char2idx['<PAD>']

    X_indices = []
    X_seq_len = []
    Y_indices = []
    Y_seq_len = []

    line_pairs = zip(X_data.split('\n'), Y_data.split('\n'))
    for line_no, (x_line, y_line) in enumerate(line_pairs, 1):
        # A negative pad count would silently yield a row longer than max_len
        # and break the rectangular arrays built below, so fail loudly here.
        if len(x_line) > max_len - 1 or len(y_line) > max_len - 1:
            raise ValueError(
                "line {}: sequences may hold at most {} characters "
                "(source has {}, target has {})".format(
                    line_no, max_len - 1, len(x_line), len(y_line)))

        x_chars = [X_char2idx.get(char, x_unk) for char in x_line]
        _x_chars = x_chars + [x_eos] + [x_pad] * (max_len - 1 - len(x_chars))

        y_chars = [X_char2idx.get(char, x_unk) for char in y_line]
        _y_chars = y_chars + [x_eos] + [x_pad] * (max_len - 1 - len(y_chars))

        # Position of the first occurrence of every token in the source row.
        first_pos = {}
        for pos, token in enumerate(_x_chars):
            first_pos.setdefault(token, pos)

        missing = [y for y in _y_chars if y not in first_pos]
        if missing:
            raise ValueError(
                "line {}: target token {!r} does not occur in the padded "
                "source sequence".format(line_no, X_idx2char[missing[0]]))

        # we are predicting the positions
        target = [first_pos[y] for y in _y_chars]

        X_indices.append(_x_chars)
        Y_indices.append(target)
        X_seq_len.append(len(x_chars) + 1)  # +1 accounts for <EOS>
        Y_seq_len.append(len(y_chars) + 1)

    X_indices = np.array(X_indices)
    Y_indices = np.array(Y_indices)
    X_seq_len = np.array(X_seq_len)
    Y_seq_len = np.array(Y_seq_len)
    print("==> Sequence Padded")

    return X_indices, X_seq_len, Y_indices, Y_seq_len, X_char2idx, X_idx2char
# end function


def train_test_split(X_indices, X_seq_len, Y_indices, Y_seq_len, BATCH_SIZE):
    """Split the four parallel sequences into a training part and a test part.

    The first ``BATCH_SIZE`` entries of every sequence form the test part;
    everything after them forms the training part.

    Args:
        X_indices: Source rows.
        X_seq_len: Source lengths.
        Y_indices: Target rows.
        Y_seq_len: Target lengths.
        BATCH_SIZE: Number of leading entries to hold out for testing.

    Returns:
        ``(train, test)``, each a 4-tuple ordered as
        ``(X, X_len, Y, Y_len)``.
    """
    held_out = slice(None, BATCH_SIZE)
    remainder = slice(BATCH_SIZE, None)
    columns = (X_indices, X_seq_len, Y_indices, Y_seq_len)

    train_part = tuple(column[remainder] for column in columns)
    test_part = tuple(column[held_out] for column in columns)
    return train_part, test_part
# end function


def main():
    """Preprocess the data, train a PointerNetwork and run three sample inferences.

    The first BATCH_SIZE samples returned by ``train_test_split`` are passed
    to ``fit`` as validation data; the rest are used for training.
    """
    MAX_LEN = 15
    BATCH_SIZE = 128

    (X_indices, X_seq_len, Y_indices, Y_seq_len,
     X_char2idx, X_idx2char) = preprocess_data(MAX_LEN)

    # Each split is a (X, X_len, Y, Y_len) tuple.
    train_set, test_set = train_test_split(
        X_indices, X_seq_len, Y_indices, Y_seq_len, BATCH_SIZE)

    model = PointerNetwork(
        max_len=MAX_LEN,
        rnn_size=50,
        X_word2idx=X_char2idx,
        embedding_dim=15)

    model.fit(*train_set,
              val_data=test_set, batch_size=BATCH_SIZE, n_epoch=200)

    for word in ('common', 'apple', 'zhedong'):
        model.infer(word, X_idx2char)
# end main


# Run the training/inference demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()


==> Word Index Built
==> Sequence Padded
Epoch 1/200 | Batch 0/77 | train_loss: 2.698 | test_loss: 2.694
Epoch 1/200 | Batch 50/77 | train_loss: 2.312 | test_loss: 2.284
Epoch 2/200 | Batch 0/77 | train_loss: 2.051 | test_loss: 2.007
Epoch 2/200 | Batch 50/77 | train_loss: 1.801 | test_loss: 1.710
Epoch 3/200 | Batch 0/77 | train_loss: 1.651 | test_loss: 1.591
Epoch 3/200 | Batch 50/77 | train_loss: 1.418 | test_loss: 1.314
Epoch 4/200 | Batch 0/77 | train_loss: 1.113 | test_loss: 1.086
Epoch 4/200 | Batch 50/77 | train_loss: 0.923 | test_loss: 0.842
Epoch 5/200 | Batch 0/77 | train_loss: 0.777 | test_loss: 0.758
Epoch 5/200 | Batch 50/77 | train_loss: 0.705 | test_loss: 0.658
Epoch 6/200 | Batch 0/77 | train_loss: 0.631 | test_loss: 0.623
Epoch 6/200 | Batch 50/77 | train_loss: 0.610 | test_loss: 0.577
Epoch 7/200 | Batch 0/77 | train_loss: 0.555 | test_loss: 0.553
Epoch 7/200 | Batch 50/77 | train_loss: 0.548 | test_loss: 0.519
Epoch 8/200 | Batch 0/77 | train_loss: 0.503 | test_loss

Epoch 63/200 | Batch 50/77 | train_loss: 0.158 | test_loss: 0.151
Epoch 64/200 | Batch 0/77 | train_loss: 0.153 | test_loss: 0.147
Epoch 64/200 | Batch 50/77 | train_loss: 0.139 | test_loss: 0.147
Epoch 65/200 | Batch 0/77 | train_loss: 0.139 | test_loss: 0.143
Epoch 65/200 | Batch 50/77 | train_loss: 0.137 | test_loss: 0.145
Epoch 66/200 | Batch 0/77 | train_loss: 0.137 | test_loss: 0.141
Epoch 66/200 | Batch 50/77 | train_loss: 0.135 | test_loss: 0.145
Epoch 67/200 | Batch 0/77 | train_loss: 0.135 | test_loss: 0.139
Epoch 67/200 | Batch 50/77 | train_loss: 0.134 | test_loss: 0.143
Epoch 68/200 | Batch 0/77 | train_loss: 0.133 | test_loss: 0.138
Epoch 68/200 | Batch 50/77 | train_loss: 0.132 | test_loss: 0.143
Epoch 69/200 | Batch 0/77 | train_loss: 0.131 | test_loss: 0.138
Epoch 69/200 | Batch 50/77 | train_loss: 0.131 | test_loss: 0.142
Epoch 70/200 | Batch 0/77 | train_loss: 0.129 | test_loss: 0.138
Epoch 70/200 | Batch 50/77 | train_loss: 0.130 | test_loss: 0.139
Epoch 71/200 | Ba

Epoch 126/200 | Batch 0/77 | train_loss: 0.082 | test_loss: 0.109
Epoch 126/200 | Batch 50/77 | train_loss: 0.090 | test_loss: 0.109
Epoch 127/200 | Batch 0/77 | train_loss: 0.081 | test_loss: 0.108
Epoch 127/200 | Batch 50/77 | train_loss: 0.090 | test_loss: 0.118
Epoch 128/200 | Batch 0/77 | train_loss: 0.087 | test_loss: 0.118
Epoch 128/200 | Batch 50/77 | train_loss: 0.091 | test_loss: 0.113
Epoch 129/200 | Batch 0/77 | train_loss: 0.082 | test_loss: 0.117
Epoch 129/200 | Batch 50/77 | train_loss: 0.095 | test_loss: 0.109
Epoch 130/200 | Batch 0/77 | train_loss: 0.081 | test_loss: 0.110
Epoch 130/200 | Batch 50/77 | train_loss: 0.088 | test_loss: 0.110
Epoch 131/200 | Batch 0/77 | train_loss: 0.084 | test_loss: 0.108
Epoch 131/200 | Batch 50/77 | train_loss: 0.102 | test_loss: 0.196
Epoch 132/200 | Batch 0/77 | train_loss: 0.087 | test_loss: 0.114
Epoch 132/200 | Batch 50/77 | train_loss: 0.093 | test_loss: 0.114
Epoch 133/200 | Batch 0/77 | train_loss: 0.081 | test_loss: 0.109
Epo

Epoch 188/200 | Batch 0/77 | train_loss: 0.066 | test_loss: 0.097
Epoch 188/200 | Batch 50/77 | train_loss: 0.071 | test_loss: 0.107
Epoch 189/200 | Batch 0/77 | train_loss: 0.066 | test_loss: 0.094
Epoch 189/200 | Batch 50/77 | train_loss: 0.070 | test_loss: 0.107
Epoch 190/200 | Batch 0/77 | train_loss: 0.066 | test_loss: 0.099
Epoch 190/200 | Batch 50/77 | train_loss: 0.069 | test_loss: 0.108
Epoch 191/200 | Batch 0/77 | train_loss: 0.067 | test_loss: 0.095
Epoch 191/200 | Batch 50/77 | train_loss: 0.067 | test_loss: 0.108
Epoch 192/200 | Batch 0/77 | train_loss: 0.067 | test_loss: 0.095
Epoch 192/200 | Batch 50/77 | train_loss: 0.065 | test_loss: 0.106
Epoch 193/200 | Batch 0/77 | train_loss: 0.067 | test_loss: 0.097
Epoch 193/200 | Batch 50/77 | train_loss: 0.068 | test_loss: 0.104
Epoch 194/200 | Batch 0/77 | train_loss: 0.065 | test_loss: 0.096
Epoch 194/200 | Batch 50/77 | train_loss: 0.064 | test_loss: 0.104
Epoch 195/200 | Batch 0/77 | train_loss: 0.066 | test_loss: 0.099
Epo