# Sequence Labeling (POS Tagging) with a BiLSTM-CRF

In [28]:
# Reserved vocabulary entries:
#   UNK  - fallback key looked up when a token is not in the vocabulary
#   NUM  - key substituted for all-digit tokens before lookup
#   NONE - the "O" label (presumably the outside/no-tag label; unused in the cells shown)
UNK, NUM, NONE = "$UNK$", "$NUM$", "O"

Load tokens (e.g., words or tags) from the vocabulary files and map each token to its index (its line number in the file).

In [29]:
def load_token2index(file_name, encoding="utf-8"):
    """Build a token -> index mapping from a one-token-per-line vocabulary file.

    The 0-based line number of each line becomes the index of the token on
    that line (surrounding whitespace stripped). If the same token appears on
    several lines, the last occurrence wins.

    Args:
        file_name: path to the vocabulary file.
        encoding: text encoding of the file. Defaults to UTF-8 so the result
            does not depend on the platform's locale encoding.

    Returns:
        dict mapping token string -> int index.
    """
    with open(file_name, encoding=encoding) as f:
        return {line.strip(): idx for idx, line in enumerate(f)}

In [30]:
# Word vocabulary first, then the POS-tag vocabulary.
word2index, tag2index = (
    load_token2index(vocab_path)
    for vocab_path in ('data/word_vocab.txt', 'data/pos_tag_vocab.txt')
)

In [31]:
# One size per line: word vocabulary, then tag vocabulary.
print(*map(len, (word2index, tag2index)), sep='\n')

22949
45


In [229]:
def get_processing_token(token2index, lowercase=False):
    """Return a function that maps a raw token string to its vocabulary index.

    The returned function optionally lowercases the token, replaces all-digit
    tokens with NUM, and looks the result up in ``token2index``. Tokens that
    are not in the vocabulary fall back to the UNK entry.

    Args:
        token2index: dict mapping token string -> int index.
        lowercase: lowercase tokens before lookup.

    Returns:
        A callable ``f(token) -> int``. It raises KeyError, naming the token,
        when the token is unknown and ``token2index`` has no UNK entry.
    """

    def f(token):
        if lowercase:
            token = token.lower()
        if token.isdigit():
            token = NUM

        if token in token2index:
            return token2index[token]
        if UNK in token2index:
            return token2index[UNK]
        # Without a fallback entry, report the token that failed rather than
        # an opaque KeyError for the UNK key itself.
        raise KeyError(
            "token {!r} is not in the vocabulary and the vocabulary has no "
            "{!r} fallback entry".format(token, UNK)
        )

    return f

In [230]:
# Words are lowercased before lookup; POS tags are looked up verbatim.
process_word_f = get_processing_token(word2index, lowercase=True)
process_tag_f = get_processing_token(tag2index, lowercase=False)

In [406]:
class Dataset(object):
    """Re-iterable stream of sentences read from a column-formatted text file.

    Each non-empty line holds whitespace-separated fields: the word is field 0
    and its POS tag is field 1; any further fields are ignored. Sentences are
    separated by blank lines or lines starting with ``-DOCSTART-``. Each
    iteration re-opens the file and yields ``(words, pos_tags)`` list pairs,
    after applying the optional processing callables to each word and tag.
    """

    def __init__(self, file_name, processing_word, processing_tag, max_iter=None,
                 encoding="utf-8"):
        """
        Args:
            file_name: path to the data file.
            processing_word: callable applied to each word, or None.
            processing_tag: callable applied to each POS tag, or None.
            max_iter: maximum number of sentences yielded per pass, or None
                for no limit.
            encoding: text encoding of the data file.
        """
        self.file_name = file_name
        self.processing_word = processing_word
        self.processing_tag = processing_tag
        self.max_iter = max_iter
        self.encoding = encoding

    def __iter__(self):
        words, pos_tags = [], []
        niter = 0  # number of sentences completed so far in this pass
        with open(self.file_name, encoding=self.encoding) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("-DOCSTART-"):
                    # Sentence boundary: emit the accumulated sentence, if any.
                    if words:
                        niter += 1
                        if self.max_iter is not None and niter > self.max_iter:
                            return
                        yield words, pos_tags
                        words, pos_tags = [], []
                else:
                    # split() tolerates tabs and repeated spaces between columns.
                    ls = line.split()
                    word, pos_tag = ls[0], ls[1]
                    if self.processing_word is not None:
                        word = self.processing_word(word)
                    if self.processing_tag is not None:
                        pos_tag = self.processing_tag(pos_tag)
                    words.append(word)
                    pos_tags.append(pos_tag)

        # Flush the final sentence when the file does not end with a blank
        # line; otherwise it would be silently dropped.
        if words and (self.max_iter is None or niter < self.max_iter):
            yield words, pos_tags


In [410]:
def minibatch(dataset, batch_size):
    """Group ``(words, tags)`` pairs from ``dataset`` into batches.

    Yields ``(word_batch, tag_batch)`` list pairs holding ``batch_size``
    sentences each; the final batch holds whatever remains (if anything).
    """
    word_buf, tag_buf = [], []
    for sentence, tag_seq in dataset:
        # Emit the full buffer before adding the next sentence.
        if len(word_buf) == batch_size:
            yield word_buf, tag_buf
            word_buf, tag_buf = [], []
        word_buf.append(sentence)
        tag_buf.append(tag_seq)

    if word_buf:
        yield word_buf, tag_buf

In [411]:
# Small 10-sentence sample of the training data for inspecting batches.
dataset = Dataset('data/train.txt', process_word_f, process_tag_f, max_iter=10)

In [443]:
# for i, (xbatch, ybatch) in enumerate(minibatch(dataset, 4)):
#     print('batch_'+str(i), len(xbatch), len(ybatch))
#     print(xbatch)
#     print(ybatch)

In [362]:
def pad_sequences(sequences, pad_tok=0):
    """Right-pad variable-length sequences to the length of the longest one.

    Args:
        sequences: iterable of sequences (e.g. lists of token indices). Any
            iterable is accepted, including a generator.
        pad_tok: value appended to shorter sequences. Defaults to 0. Index 0
            is also a real vocabulary entry here, so padded positions are
            expected to be ignored via the returned lengths.

    Returns:
        (sequences_pad, sequences_length): a list of equal-length lists, and
        the original length of each sequence. Both are empty lists when
        ``sequences`` is empty.
    """
    # Materialize once so generators are not exhausted by the length scan.
    sequences = [list(seq) for seq in sequences]
    if not sequences:
        return [], []

    max_len = max(len(seq) for seq in sequences)
    sequences_pad = [seq + [pad_tok] * (max_len - len(seq)) for seq in sequences]
    sequences_length = [len(seq) for seq in sequences]
    return sequences_pad, sequences_length

In [444]:
# seq = [[19176, 11783, 4637, 15283, 18989, 3386, 19850, 5296, 15298], [1889, 2142], [11319, 9722]]
# seq_pad, seq_len = pad_sequences(seq)
# print(seq_pad)
# print(seq_len)

## 2. Build Model

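* Encoder
    1. placeholders for the model inputs
    2. an embedding lookup that maps word indices to vectors
    3. a bidirectional LSTM that turns the embedded words into contextual representations
* Decoder
    1. decoding for inference (Viterbi decoding with a CRF, or per-token argmax without one)
    2. the training loss (CRF log-likelihood, or per-token softmax cross-entropy)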

In [22]:
import tensorflow as tf
import numpy as np

In [753]:
# Vocabulary sizes come straight from the loaded index tables.
word_num, tag_num = len(word2index), len(tag2index)

dim_word = 50                 # embedding size; also selects the glove.6B.<dim>d file
hidden_size_lstm = 300        # hidden units per LSTM direction
use_crf = True                # CRF loss + Viterbi decoding instead of per-token softmax
pretrained_embedding = True   # initialize the embedding table from GloVe

### 2.1 Model Input

* word indices with shape `(batch_size, max_sequence_length)`
* tag labels with shape `(batch_size, max_sequence_length)` (one integer tag index per token)
* sequence lengths with shape `(batch_size)`
* hyperparameters
    * dropout keep probability (fed through the `dropout_rate` placeholder; 1.0 disables dropout)
    * learning rate

In [754]:
# Start model construction from an empty default graph (e.g. when re-running
# the graph-building cells below in the same notebook session).
tf.reset_default_graph()

In [755]:
# Padded word-index matrix, (batch, max_len); fed from pad_sequences output.
word_ids = tf.placeholder(shape=[None, None], dtype=tf.int32, name='word_ids')
# Unpadded length of each sentence in the batch, (batch,).
sequence_lengths = tf.placeholder(shape=[None], dtype=tf.int32, name='sequence_lengths')
# Padded gold tag indices, (batch, max_len); only fed during training.
labels = tf.placeholder(shape=[None, None], dtype=tf.int32, name='label')

learning_rate = tf.placeholder( dtype=tf.float32, name='learning_rate')
# NOTE: despite its name this is the KEEP probability passed to tf.nn.dropout
# (TF 1.x signature); it is fed 1.0 at evaluation time and the keep rate `kr` when training.
dropout_rate = tf.placeholder( dtype=tf.float32, name='dropout_rate')

### 2.2 Word Embedding

In [756]:
if not pretrained_embedding:
    # Randomly initialized, trainable embedding table: one row per word index.
    word_embedding_matrix = tf.get_variable(shape=[word_num, dim_word],
                                            dtype=tf.float32,
                                            name='word_embedding_matrix')
else:
    # Initialize from GloVe. Vocabulary words that do not occur in the GloVe
    # file keep an all-zero row (the table remains trainable).
    filename_glove = "data/glove.6B/glove.6B.{}d.txt".format(dim_word)
    embedding_matrix = np.zeros((word_num, dim_word), dtype=np.float32)
    # GloVe files are UTF-8; don't depend on the platform's locale encoding.
    with open(filename_glove, encoding="utf-8") as f:
        for line in f:
            line = line.strip().split(' ')
            word = line[0]
            if word in word2index:
                # Parse the vector explicitly as floats instead of relying on
                # an implicit string -> float cast during assignment.
                embedding_matrix[word2index[word]] = np.asarray(line[1:], dtype=np.float32)

    word_embedding_matrix = tf.Variable(embedding_matrix,
                                        dtype=tf.float32,
                                        name='pretrained_word_embedding_matrix',
                                        trainable=True)

# (batch, max_len) word indices -> (batch, max_len, dim_word) vectors.
word_embedding = tf.nn.embedding_lookup(word_embedding_matrix,
                                        word_ids,
                                        name='word_embedding_lookup')

### 2.3 Encoder

In [757]:
# Bidirectional LSTM encoder over the embedded sentence.
with tf.variable_scope('bi_lstm_encoder'):
    # Separate forward / backward cells, each with hidden_size_lstm units.
    cell_fw = tf.contrib.rnn.LSTMCell(hidden_size_lstm)
    cell_bw = tf.contrib.rnn.LSTMCell(hidden_size_lstm)
    
    # sequence_length makes the RNN stop at each sentence's real length, so
    # padded positions do not feed into the backward pass.
    (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, 
                                                                cell_bw, 
                                                                word_embedding, 
                                                                sequence_length=sequence_lengths, 
                                                                dtype=tf.float32)
    # Concatenate both directions per time step -> last dim 2*hidden_size_lstm.
    output = tf.concat([output_fw, output_bw], axis=-1)
    # Second argument is the keep probability (TF 1.x tf.nn.dropout).
    output = tf.nn.dropout(output, dropout_rate)

### 2.4 Decoder

In [758]:
# Per-token linear projection from encoder features to unnormalized tag scores.
with tf.variable_scope('decoder'):
    W = tf.get_variable(shape=[2*hidden_size_lstm, tag_num], dtype=tf.float32, name='proj_W')
    b = tf.get_variable(shape=[tag_num], dtype=tf.float32, name='proj_b')
    
    # Flatten (batch, time, features) to (batch*time, features) so one matmul
    # scores every token, then restore the time axis afterwards.
    time_steps = tf.shape(output)[1]
    output = tf.reshape(output, [-1, 2*hidden_size_lstm])
    pred = tf.matmul(output, W) + b
    logits = tf.reshape(pred, [-1, time_steps, tag_num])  # (batch, time, tag_num)
    logits_shape = tf.shape(logits)  # kept for shape debugging

In [759]:
# Softmax-mode prediction: the highest-scoring tag per token. In CRF mode no
# prediction tensor is built here; tags are Viterbi-decoded from `logits` and
# the learned transition matrix at evaluation time.
with tf.variable_scope('projection'):
    if not use_crf:
        label_pred = tf.cast(tf.argmax(logits, axis=-1), tf.int32)

In [760]:
# Example shapes from a debug run (batch of 20 sentences):
# logits_shape [20 40 45]
# trans_params_shape [45 45]
# log_likelihood_shape [20]

with tf.variable_scope('loss'):
    if use_crf:
        # Sentence-level CRF: returns each sentence's log-likelihood and the
        # learned (tag_num, tag_num) transition matrix used later for Viterbi.
        log_likelihood, trans_params = tf.contrib.crf.crf_log_likelihood(logits, 
                                                                         labels, 
                                                                         sequence_lengths) 
        loss = tf.reduce_mean(-log_likelihood)
        log_likelihood_shape = tf.shape(log_likelihood)
        trans_params_shape = tf.shape(trans_params)
    else:
        # Per-token cross-entropy; the mask drops padded positions before averaging.
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        mask = tf.sequence_mask(sequence_lengths)
        losses = tf.boolean_mask(losses, mask)
        loss = tf.reduce_mean(losses)

### 2.5 Optimizer

In [761]:
# Adam with a feedable learning rate. The IndexedSlices warning printed below
# presumably comes from the embedding lookup's sparse gradient being densified.
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(loss)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


## 3. Train Model

In [762]:
def validation(valid_dataset, sess, batch_size):
    """Evaluate per-token tagging accuracy of the current model.

    Runs the graph with keep probability 1.0 (no dropout). Tags are decoded
    with Viterbi over the CRF transition matrix when ``use_crf`` is set, or
    taken from the per-token argmax otherwise.

    Args:
        valid_dataset: iterable of ``(word_indices, tag_indices)`` pairs.
        sess: a tf.Session whose variables are initialized or restored.
        batch_size: number of sentences per evaluation batch.

    Returns:
        ``(overall_acc, ret)``. ``overall_acc`` is the fraction of correctly
        tagged tokens (NaN if the dataset contains no tokens). ``ret`` is a
        list with one ``(words, gold_tags, predicted_tags, per_token_correct)``
        tuple per sentence.
    """
    accs = []
    ret = []
    # Gold tags are called `ybatch` so the module-level `labels` placeholder
    # is not shadowed inside this function.
    for xbatch, ybatch in minibatch(valid_dataset, batch_size):
        word_seq, sequence_len = pad_sequences(xbatch)

        feed = {
            word_ids: word_seq,
            sequence_lengths: sequence_len,
            dropout_rate: 1.0,
        }

        if use_crf:
            logits_v, trans_params_v = sess.run([logits, trans_params], feed_dict=feed)
            labels_pred_v = []
            for logit, seq_length in zip(logits_v, sequence_len):
                # Decode only the real (unpadded) time steps of each sentence.
                viterbi_seq, _ = tf.contrib.crf.viterbi_decode(logit[:seq_length], trans_params_v)
                labels_pred_v.append(viterbi_seq)
        else:
            labels_pred_v = sess.run(label_pred, feed_dict=feed)

        for words, lab, lab_pred, seq_length in zip(xbatch, ybatch, labels_pred_v, sequence_len):
            lab = lab[:seq_length]
            lab_pred = lab_pred[:seq_length]
            acc = [a == b for (a, b) in zip(lab, lab_pred)]
            ret.append((words, lab, lab_pred, acc))
            accs += acc

    # np.mean([]) would warn and return NaN; make the empty case explicit.
    overall_acc = np.mean(accs) if accs else float('nan')

    return overall_acc, ret

In [763]:
# Full (uncapped) training and validation splits.
dataset, valid_dataset = (
    Dataset(split_path, process_word_f, process_tag_f)
    for split_path in ('data/train.txt', 'data/valid.txt')
)

In [764]:
# Training hyperparameters.
batch_size = 20   # sentences per minibatch
nepochs = 10      # passes over the training set
lr = 0.01         # learning rate
kr = 0.7          # dropout keep rate (keep probability)

In [765]:
# Saver over all saveable global variables (default var_list). It is created
# after the whole graph is built, so it covers every variable defined above.
saver = tf.train.Saver()

In [772]:
import os

# Train for `nepochs` epochs. The running mean loss is printed every 10
# batches and validation accuracy every 50 batches; the final weights are
# then checkpointed.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for ep in range(nepochs):
        # Named so it does not rebind the `losses` tensor from the non-CRF loss branch.
        epoch_losses = []
        for i, (xbatch, ybatch) in enumerate(minibatch(dataset, batch_size), start=1):
            word_seq, sequence_len = pad_sequences(xbatch)
            target_seq, _ = pad_sequences(ybatch)

            feed = {
                word_ids: word_seq,
                labels: target_seq,
                sequence_lengths: sequence_len,
                learning_rate: lr,
                # `dropout_rate` receives the keep probability for tf.nn.dropout.
                dropout_rate: kr,
            }

            _, train_loss = sess.run([train_op, loss], feed_dict=feed)
            epoch_losses.append(train_loss)

            if i % 10 == 0:
                print('ep:', ep, 'iter:', i, 'loss:', np.mean(epoch_losses))
            if i % 50 == 0:
                acc, _ = validation(valid_dataset, sess, batch_size)
                print('accuracy', acc)

    # tf.train.Saver.save raises if the checkpoint's parent directory is missing.
    os.makedirs("checkpoints", exist_ok=True)
    saver.save(sess, "checkpoints/ner.ckpt")

## 4. Test

In [767]:
# Inverse lookup tables (index -> token) for turning indices back into text.
index2word = dict(zip(word2index.values(), word2index.keys()))
index2tag = dict(zip(tag2index.values(), tag2index.keys()))

In [768]:
def parse(index2token, token_indices):
    """Translate a sequence of indices into the list of their tokens.

    Raises KeyError if an index is missing from ``index2token``.
    """
    return list(map(index2token.__getitem__, token_indices))

In [769]:
# Sanity check: decode a handful of tag indices back to tag strings.
parse(index2tag, [0, 3, 2, 1])

['RP', 'NNS', ')', 'EX']

In [770]:
# First 5 sentences of the test split, for a quick qualitative check.
test_dataset = Dataset('data/test.txt', process_word_f, process_tag_f, max_iter=5)

In [771]:
# Restore the trained checkpoint, evaluate on the test sample, and print the
# words with their gold and predicted tags plus per-token correctness.
with tf.Session() as sess:
    saver.restore(sess, "checkpoints/ner.ckpt")
    overall_acc, result = validation(test_dataset, sess, 20)

    for words, gold_tags, predicted_tags, token_hits in result:
        print('------------------------------------------------------')
        print(parse(index2word, words))
        print(parse(index2tag, gold_tags))
        print(parse(index2tag, predicted_tags))
        print('accuracy:', token_hits)
    print("acc", overall_acc)

INFO:tensorflow:Restoring parameters from checkpoints/ner.ckpt
------------------------------------------------------
['soccer', '-', 'japan', 'get', 'lucky', 'win', ',', 'china', 'in', 'surprise', 'defeat', '.']
['NN', ':', 'NNP', 'VB', 'NNP', 'NNP', ',', 'NNP', 'IN', 'DT', 'NN', '.']
['NN', ':', 'NNP', 'NNP', 'NNP', 'NNP', ',', 'NNP', 'IN', 'NNP', 'NNP', '.']
accuracy: [True, True, True, False, True, True, True, True, True, False, False, True]
------------------------------------------------------
['nadim', 'ladki']
['NNP', 'NNP']
['NNP', 'SYM']
accuracy: [True, False]
------------------------------------------------------
['al-ain', ',', 'united', 'arab', 'emirates', '$UNK$']
['NNP', ',', 'NNP', 'NNP', 'NNPS', 'CD']
['NNP', ',', 'NNP', 'NNP', 'NNP', 'CD']
accuracy: [True, True, True, True, False, True]
------------------------------------------------------
['japan', 'began', 'the', 'defence', 'of', 'their', 'asian', 'cup', 'title', 'with', 'a', 'lucky', '2-1', 'win', 'against', 'syr