In [4]:
import tensorflow as tf
import numpy as np
import gensim

In [10]:
glove_model = gensim.models.KeyedVectors.load('glove.model')

In [52]:
class ConvRecurrentAttentionNet:
    def __init__(self, embedding_model, seq_length, embedding_dim, filter_size, num_filters, hidden_size, batch_size, num_classes):
        self.embedding_model = embedding_model
        self.seq_length = seq_length
        self.embedding_dim = embedding_dim
        self.filter_size = filter_size
        self.num_filters = num_filters
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.num_classes = num_classes
        
        
    def word_embedding(self, inputs, reuse=False):
        with tf.variable_scope('word_embedding', reuse=reuse):
            embedding_W = tf.get_variable('embedding_W',
                                          shape=[self.embedding_model.vectors.shape[0], self.embedding_model.vectors.shape[1]],
                                          initializer=tf.constant_initializer(self.embedding_model.vectors),
                                          trainable=True)
            
            embedded_X = tf.nn.embedding_lookup(embedding_W, inputs)
            
        return embedded_X
    

    def attention_extraction(self, inputs, reuse=False):
        with tf.variable_scope('attention_extraction', reuse=reuse):
            input_reshaped = tf.reshape(inputs, [-1, self.seq_length, self.embedding_dim, 1])
            paddings = tf.constant([[0, 0], [1, 1], [0, 0], [0, 0]], dtype='int32')
            input_padded = tf.pad(input_reshaped, paddings, 'CONSTANT')
            
            filter_shape = [self.filter_size, self.embedding_dim, 1, self.num_filters]
            
            W = tf.get_variable('cnn_filter', 
                                shape=filter_shape, 
                                initializer=tf.truncated_normal_initializer(stddev=0.1))
            
            b = tf.get_variable('cnn_bias',
                                shape=[num_filters],
                                initializer=tf.constant_initializer(0.0))
            
            conv = tf.nn.conv2d(input=input_padded, 
                                filter=W, 
                                strides=[1, 1, 1, 1], 
                                padding='VALID',
                                name='convolution')
            
            conv_output = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')
            attention_signal = tf.reduce_mean(tf.transpose(tf.squeeze(conv_output), perm=[0, 2, 1]), 
                                              axis=1, 
                                              name='attention_signal')
        
        return attention_signal
    
    
    def lstm_encoder(self, inputs, reuse=False):
        with tf.variable_scope('lstm_encoder', reuse=reuse):
            lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_size)
            lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=0.5)
            hiddens, _ = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)
        
        return hiddens
    
    
    def classification(self, attention_signal, hiddens, reuse=False):
        with tf.variable_scope('classification', reuse=reuse):    
            whole_seq = tf.reduce_mean(tf.multiply(hiddens, tf.expand_dims(attention_signal, axis=2)), axis=1)
            
            W = tf.get_variable('output_weights', 
                                shape=[self.embedding_dim, self.num_classes], 
                                initializer=tf.truncated_normal_initializer(stddev=0.1))
            
            b = tf.get_variable('output_bias',
                                shape=[self.num_classes],
                                initializer=tf.constant_initializer(0.0))
            
            output = tf.nn.bias_add(tf.matmul(whole_seq, W), b)
            
        return output
    
    
    def build_graph(self, inputs, reuse=False):
        embedded_X = self.word_embedding(inputs, reuse=reuse)
        attention_signal = self.attention_extraction(embedded_X, reuse=reuse)
        hiddens = self.lstm_encoder(embedded_X, reuse=reuse)
        output = self.classification(attention_signal, hiddens, reuse=reuse)
        
        return output

In [41]:
def parser(serialized_example):
    features = {
        'document': tf.FixedLenFeature([18], tf.int64),
        'label': tf.FixedLenFeature([6], tf.int64)
    }

    parsed_feature = tf.parse_single_example(serialized_example, features)

    document = parsed_feature['document']
    label = parsed_feature['label']

    return document, label

def read_tfrecord(fname, parser, shuffle_size, batch_size, seq_length):
    dataset = tf.data.TFRecordDataset(fname).map(parser)
    dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size)
    iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
    feature, label = iterator.get_next()
    feature = tf.reshape(feature, [-1, seq_length])
    feature = tf.cast(feature, tf.int64)
    return iterator, dataset, feature, label

In [57]:
seq_length = 18
embedding_dim = 100
filter_size = 3
num_filters = 50
hidden_size = 100
batch_size = 16
num_classes = 6

train_data_dir = './train.tfrecord'
shuffle_size = 100000

learning_rate = 0.001

In [49]:
# embedding_model, seq_length, embedding_dim, filter_size, num_of_filters, hidden_size, batch_size, class_num

model = ConvRecurrentAttentionNet(glove_model, seq_length, embedding_dim, filter_size, num_of_filters, hidden_size, batch_size, class_num)

In [58]:
tf.reset_default_graph()

train_itr, train_dataset, train_X, train_y = read_tfrecord(train_data_dir, 
                                                           parser, 
                                                           shuffle_size, 
                                                           batch_size, 
                                                           seq_length)

train_init_op = train_itr.make_initializer(train_dataset)

X = tf.placeholder(tf.int64, [None, seq_length])
Y = tf.placeholder(tf.float32, [None, num_classes])

logits = model.build_graph(X)

global_step = tf.Variable(0, trainable=False, name='global_step')

loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=Y, logits=logits, name='loss')

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

predict_proba = tf.nn.sigmoid(logits)

auc = tf.metrics.auc(Y, predict_proba)

In [59]:
epochs = 1

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    
    writer = tf.summary.FileWriter('./graphs', sess.graph)
    
    for epoch in range(epochs):
        sess.run(train_init_op)
        auc_list = []
        
        while True:
            try:
                step = sess.run(global_step)
                
                _X, _Y = sess.run([train_X, train_y])
                _, _loss, _auc = sess.run([optimizer, loss, auc], feed_dict = {X: _X, Y: _Y})
                auc_list.append(_auc)
                
                if (step > 0) and (step % 500 == 0):
                    print('Step: {}, Cost: {}'.format(step, _loss))
                    
            except tf.errors.OutOfRangeError:
                break
                    
#         print(result)

Step: 500, Cost: [[0.6931644  0.6931571  0.6931575  0.6931505  0.6931559  0.6931586 ]
 [0.6931612  0.6931513  0.69315153 0.69315094 0.6931548  0.6931571 ]
 [0.6931533  0.69315267 0.693155   0.69315344 0.693157   0.693154  ]
 [0.6931584  0.6931519  0.6931508  0.6931638  0.6931548  0.69315904]
 [0.69315934 0.6931547  0.6931505  0.69315493 0.6931546  0.6931602 ]
 [0.6931578  0.6931536  0.69315267 0.6931545  0.69315743 0.69315857]
 [0.69316405 0.6931597  0.69315684 0.69316465 0.69315493 0.6931662 ]
 [0.6931587  0.69315255 0.69315773 0.69315356 0.6931595  0.69315827]
 [0.6931558  0.6931564  0.69315046 0.6931515  0.69315207 0.6931719 ]
 [0.6931532  0.69315284 0.69315165 0.69315165 0.6931568  0.69317317]
 [0.69313    0.6931415  0.6931439  0.69315296 0.6931383  0.6931609 ]
 [0.6931324  0.6931556  0.6931423  0.6931548  0.69314003 0.6931612 ]
 [0.69315946 0.6931591  0.69315475 0.6931534  0.6931557  0.69316214]
 [0.69315517 0.69315034 0.6931549  0.69315296 0.6931546  0.6931569 ]
 [0.6931623  0.69

KeyboardInterrupt: 