In [1]:
# -*- coding: utf-8 -*-
import sys
import chseg
import numpy as np
import tensorflow as tf

from collections import Counter
from multihead_attn_clf import Tagger
from sklearn.metrics import classification_report


# --- Model / training configuration -----------------------------------------
SEQ_LEN = 50        # length of every window fed to the tagger (train and test)
N_CLASS = 4         # tag set: B=0, M=1, E=2, S=3
N_EPOCH = 1
BATCH_SIZE = 128

# Demo sentence segmented at the end of the run.
sample = '我来到大学读书，希望学到知识'
# Major Python version; on Python 2 ``sample`` must be decoded from UTF-8 before
# it can be treated as a sequence of characters.
py = sys.version_info[0]


def to_train_seq(*args):
    """Window each training sequence with ``iter_seq``.

    Returns:
        list with one ``iter_seq`` result per positional argument, in order.
    """
    return [iter_seq(seq) for seq in args]


def to_test_seq(*args):
    """Reshape each flat sequence into non-overlapping rows of ``SEQ_LEN`` items.

    Trailing items that would not fill a complete row are discarded before
    reshaping, so for flat 1-D input each result has ``len(x) // SEQ_LEN`` rows.

    Returns:
        list of np.ndarray, one per positional argument, in the same order.
    """
    reshaped = []
    for seq in args:
        # Largest multiple of SEQ_LEN that fits in seq.
        usable = (len(seq) // SEQ_LEN) * SEQ_LEN
        reshaped.append(np.reshape(seq[:usable], [-1, SEQ_LEN]))
    return reshaped


def iter_seq(x, text_iter_step=10, seq_len=None):
    """Cut a sequence into overlapping fixed-length windows.

    Windows start every ``text_iter_step`` positions. Only windows that fit
    entirely inside ``x`` are kept, so every row has exactly ``seq_len`` items.

    Args:
        x: sequence of token ids or labels (list or 1-D array).
        text_iter_step: stride between consecutive window start positions.
        seq_len: window length; defaults to the module-level ``SEQ_LEN``.

    Returns:
        np.ndarray with one window per row. When ``x`` is shorter than
        ``seq_len`` the result is an empty array of shape ``(0, seq_len)``.
    """
    if seq_len is None:
        seq_len = SEQ_LEN
    # ``+ 1`` so the last window, ending exactly at len(x), is included; without
    # it a sequence of length seq_len would yield no windows at all.
    starts = range(0, len(x) - seq_len + 1, text_iter_step)
    windows = [x[i : i + seq_len] for i in starts]
    if not windows:
        # Keep the result 2-D so callers can rely on ``.shape[1] == seq_len``.
        return np.empty((0, seq_len), dtype=np.asarray(x).dtype)
    return np.array(windows)


def labels_to_words(chars, labels):
    """Join characters into a space-separated string using BMES tag ids.

    A space is appended after every character tagged E (2) or S (3), i.e. at
    the end of each word. The result therefore ends with a space whenever the
    final character closes a word.

    Args:
        chars: sequence of single (already decoded) characters.
        labels: per-character tag ids; pairing stops at the shorter of the two.

    Returns:
        The segmented string.
    """
    pieces = []
    for ch, label in zip(chars, labels):
        pieces.append(ch + ' ' if label in (2, 3) else ch)
    return ''.join(pieces)


if __name__ == '__main__':
    x_train, y_train, x_test, y_test, vocab_size, word2idx, idx2word = chseg.load_data()
    X_train, Y_train = to_train_seq(x_train, y_train)
    X_test, Y_test = to_test_seq(x_test, y_test)
    print('Vocab size: %d' % vocab_size)

    clf = Tagger(vocab_size, N_CLASS, SEQ_LEN)
    clf.fit(X_train, Y_train, n_epoch=N_EPOCH, batch_size=BATCH_SIZE)

    y_pred = clf.predict(X_test, batch_size=BATCH_SIZE)
    print(classification_report(Y_test.ravel(), y_pred.ravel(), target_names=['B', 'M', 'E', 'S']))

    # Use the decoded characters everywhere below. On Python 2 ``sample`` is a
    # UTF-8 byte string, so iterating or len() on it would work on bytes.
    chars = list(sample) if py == 3 else list(sample.decode('utf-8'))
    # Otherwise the padding count below goes negative and the input overflows the window.
    assert len(chars) <= SEQ_LEN, 'sample must fit in one SEQ_LEN window'
    # Right-pad to a full window; presumably id 0 is padding in chseg's vocab — confirm.
    _test = [word2idx[w] for w in chars] + [0] * (SEQ_LEN - len(chars))
    labels = clf.infer(_test, len(chars))
    print(labels_to_words(chars, labels[:len(chars)]))

Vocab size: 4533
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Data Shuffled
Epoch 1/1 | Step 0/1142 | train_loss: 100.7424 | train_acc: 0.2237 | lr: 0.0050
Epoch 1/1 | Step 50/1142 | train_loss: 23.0107 | train_acc: 0.8127 | lr: 0.0045
Epoch 1/1 | Step 100/1142 | train_loss: 17.9541 | train_acc: 0.8562 | lr: 0.0041
Epoch 1/1 | Step 150/1142 | train_loss: 12.8941 | train_acc: 0.8913 | lr: 0.0037
Epoch 1/1 | Step 200/1142 | train_loss: 10.6346 | train_acc: 0.9172 | lr: 0.0033
Epoch 1/1 | Step 250/1142 | train_loss: 9.7816 | train_acc: 0.9166 | lr: 0.0030
Epoch 1/1 | Step 300/1142 | train_loss: 9.0972 | train_acc: 0.9252 | lr: 0.0027
Epoch 1/1 | Step 350/1142 | train_loss: 8.2243 | train_acc: 0.9298 | lr: 0.0025
Epoch 1/1 | Step 400/1142 | train_loss: 6.7713 | train_acc: 0.9439 | lr: 0.0022
Epoch 1/1 | Step 450/1142 | train_loss: 6.7307 | train_acc: 0.9389 | lr: 0.0020
Epoch 1/1 | Step 500/1142 | train_loss: 6.2553 | train_acc: 0.9448 | lr: 0.0018
Epoch 1/1 | Step 550/1142 | train_loss: 6.0360 | train_acc: 0.9458 | lr: 0.0016
Epoch 1/1 | Step 600/11