In [1]:
import tensorflow as tf
import numpy as np

### 定义TextRNN结构，使用双向的LSTM

In [2]:
class TextRNN:
    """Bidirectional-LSTM text classifier with a multi-label (sigmoid) output.

    Constructing an instance adds the complete model to the current default
    TensorFlow graph: placeholders, embedding lookup, forward/backward LSTM,
    dense output layer, loss and training op.

    Args:
        batch_size: stored on the instance; the placeholders use a dynamic
            batch dimension, so it does not constrain the fed batch size.
        num_classes: number of output labels (width of the logits).
        vocab_size: number of rows of the embedding matrix.
        sentence_len: number of token ids per input row.
        embed_size: embedding width; also used as the LSTM hidden size.
        learning_rate: initial learning rate passed to the exponential decay.
        decay_steps: decay period, in global steps.
        decay_rate: multiplicative decay factor.
        is_training: stored on the instance only; no graph op reads it.

    Attributes of interest:
        input_x: int32 placeholder ``[None, sentence_len]`` of token ids.
        input_y: float32 placeholder ``[None, num_classes]`` of multi-hot labels.
        dropout_keep_prob: float32 placeholder; keep probability of the dropout
            applied before the output layer (feed 1.0 to disable dropout).
        logits: ``[batch_size, num_classes]`` un-normalised scores.
        loss_val: scalar loss (sigmoid cross-entropy + L2 penalty).
        train_op: op that performs one optimisation step.
    """

    def __init__(self, batch_size, num_classes, vocab_size, sentence_len, embed_size, 
                 learning_rate, decay_steps, decay_rate, is_training):
        # 1. Hyper-parameters
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.vocab_size = vocab_size
        self.sentence_len = sentence_len
        self.embed_size = embed_size
        self.hidden_size = embed_size  # LSTM hidden size is tied to the embedding size
        self.learning_rate = learning_rate
        self.is_training = is_training
        # NOTE(review): this initializer is not passed to any get_variable call
        # below, so the variables use TensorFlow's default initializer.
        self.initializer = tf.random_normal_initializer(stddev=0.1)
        
        # Step / epoch bookkeeping
        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
        self.epoch_increment = tf.assign(self.epoch_step, tf.add(self.epoch_step, tf.constant(1)))
        self.decay_steps, self.decay_rate = decay_steps, decay_rate
        
        # 2. Inputs
        self.input_x = tf.placeholder(tf.int32, [None, sentence_len], 'input_x')
        # Multi-label targets: one float per class (multi-hot row).
        self.input_y = tf.placeholder(tf.float32, [None, num_classes], 'input_y')
        self.dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')
        
        # 3. Embedding and output-layer parameters
        self.init_weight()
        
        # 4. Network
        self.logits = self.inference()  # [batch_size, num_classes]
        
        # 5. Loss
        self.loss_val = self.loss()
        
        # 6. Optimizer
        self.train_op = self.train()
        
    def init_weight(self):
        """Create the embedding matrix and the output-layer weight and bias."""
        self.Embedding = tf.get_variable('Embedding', [self.vocab_size, self.embed_size], dtype=tf.float32)
        # Forward and backward LSTM features are concatenated, hence hidden_size * 2.
        self.W = tf.get_variable('W', [self.hidden_size * 2, self.num_classes], dtype=tf.float32)
        self.b = tf.get_variable('b', [self.num_classes], dtype=tf.float32)
        
    def inference(self):
        """Build embedding -> bidirectional LSTM -> dropout -> dense layer.

        Returns:
            Tensor of shape ``[batch_size, num_classes]`` holding the logits.
        """
        # a. Embedding lookup: [batch_size, sentence_len, embed_size]
        self.sentence_embed = tf.nn.embedding_lookup(self.Embedding, self.input_x)
        
        # b. Bidirectional LSTM
        self.fw_cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size)  # forward cell
        self.bw_cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size)  # backward cell
        # outputs is a pair (forward, backward); each element is
        # [batch_size, sentence_len, hidden_size] and is aligned with the input
        # time axis.
        outputs, _ = tf.nn.bidirectional_dynamic_rnn(self.fw_cell, self.bw_cell, self.sentence_embed, dtype=tf.float32)
        
        # c. Take, for each direction, the step that has read the whole sentence.
        # The forward pass finishes at the last time step. The backward pass
        # reads right-to-left and its outputs are re-aligned to input order, so
        # its final step sits at index 0 (index -1 has only seen one token).
        fw_output = outputs[0][:, -1, :]  # [batch_size, hidden_size]
        bw_output = outputs[1][:, 0, :]   # [batch_size, hidden_size]
        final_output = tf.concat([fw_output, bw_output], axis=1)  # [batch_size, hidden_size * 2]
        # Dropout on the sentence feature; a no-op when keep_prob is fed as 1.0.
        final_output = tf.nn.dropout(final_output, keep_prob=self.dropout_keep_prob)
        
        # d. Fully connected output layer
        logits = tf.matmul(final_output, self.W) + self.b  # [batch_size, num_classes]
        return logits
    
    def loss(self, l2_lambda=0.0001):
        """Build the training loss.

        Args:
            l2_lambda: weight of the L2 penalty over all trainable variables.

        Returns:
            Scalar tensor: per-example sum of the per-class sigmoid
            cross-entropy, averaged over the batch, plus the L2 penalty.
        """
        # Each class is an independent binary decision (multi-label), so the
        # element-wise sigmoid cross-entropy is used rather than softmax.
        loss1 = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.input_y, logits=self.logits)
        loss1 = tf.reduce_mean(tf.reduce_sum(loss1, axis=1))
        loss2 = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()]) * l2_lambda
        return loss1 + loss2
    
    def train(self):
        """Build the Adam training op with a staircase exponentially decayed learning rate."""
        learning_rate = tf.train.exponential_decay(self.learning_rate, self.global_step, self.decay_steps, self.decay_rate, staircase=True)
        train_op = tf.contrib.layers.optimize_loss(self.loss_val, self.global_step, learning_rate, optimizer='Adam')
        return train_op
    

### 测试

In [3]:
def test():
    """Smoke-test TextRNN: fit one fixed random batch for 20 steps and print the result.

    After every step the labels and the per-row ``np.argsort`` of the logits
    (class indices ordered from lowest to highest score) are printed.
    """
    n_labels = 5
    init_lr = 0.01
    n_rows = 5
    lr_decay_steps = 1000
    lr_decay_rate = 0.9
    seq_len = 5
    n_vocab = 10000
    emb_dim = 100
    training = True
    keep_prob = 0.5

    model = TextRNN(n_rows, n_labels, n_vocab, seq_len, emb_dim,
                    init_lr, lr_decay_steps, lr_decay_rate, training)
    print(tf.trainable_variables())
    fetches = [model.loss_val, model.logits, model.train_op]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # One random batch, reused for every step.
        input_x = np.random.randint(0, 100, size=(n_rows, seq_len), dtype=np.int32)
        input_y = np.random.randint(0, 2, size=(n_rows, n_labels), dtype=np.int32)
        feed = {
            model.input_x: input_x,
            model.input_y: input_y,
            model.dropout_keep_prob: keep_prob,
        }
        for _ in range(20):
            _, step_logits, _ = sess.run(fetches, feed_dict=feed)
            ranking = np.argsort(step_logits)
            print('****label****\n', input_y)
            print('****pre_y****\n', ranking, '\n')

In [4]:
# Start from an empty default graph so re-running this cell does not collide
# with variables created by a previous TextRNN instance, then run the smoke test.
tf.reset_default_graph()
test()

[<tf.Variable 'Embedding:0' shape=(10000, 100) dtype=float32_ref>, <tf.Variable 'W:0' shape=(200, 5) dtype=float32_ref>, <tf.Variable 'b:0' shape=(5,) dtype=float32_ref>, <tf.Variable 'bidirectional_rnn/fw/basic_lstm_cell/kernel:0' shape=(200, 400) dtype=float32_ref>, <tf.Variable 'bidirectional_rnn/fw/basic_lstm_cell/bias:0' shape=(400,) dtype=float32_ref>, <tf.Variable 'bidirectional_rnn/bw/basic_lstm_cell/kernel:0' shape=(200, 400) dtype=float32_ref>, <tf.Variable 'bidirectional_rnn/bw/basic_lstm_cell/bias:0' shape=(400,) dtype=float32_ref>]
****label****
 [[0 0 1 1 0]
 [0 1 0 1 0]
 [0 0 0 1 0]
 [0 0 0 0 1]
 [1 1 0 1 0]]
****pre_y****
 [[4 2 0 3 1]
 [4 2 0 3 1]
 [4 2 0 3 1]
 [4 2 0 3 1]
 [4 2 0 3 1]] 

****label****
 [[0 0 1 1 0]
 [0 1 0 1 0]
 [0 0 0 1 0]
 [0 0 0 0 1]
 [1 1 0 1 0]]
****pre_y****
 [[4 2 0 3 1]
 [4 2 0 3 1]
 [4 2 0 3 1]
 [4 2 0 3 1]
 [4 2 0 3 1]] 

****label****
 [[0 0 1 1 0]
 [0 1 0 1 0]
 [0 0 0 1 0]
 [0 0 0 0 1]
 [1 1 0 1 0]]
****pre_y****
 [[4 2 0 1 3]
 [4 2 0 1 3]

### 训练

In [5]:
import sys
import tensorflow as tf
import numpy as np
# from tflearn.data_utils import to_categorical, pad_sequences
import os
import pickle
import h5py

In [6]:
# Hyper-parameters and paths, registered as TensorFlow command-line flags.
# NOTE: each flag can only be defined once per process; re-running this cell
# in the same kernel is expected to fail with a duplicate-flag error.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('batch_size', 1024, 'batch_size')
tf.app.flags.DEFINE_integer('num_classes', 1999, 'num_classes')
tf.app.flags.DEFINE_integer('sentence_len', 100, 'length of each sentence')
tf.app.flags.DEFINE_integer('embed_size', 100, 'embedding size')
tf.app.flags.DEFINE_float('learning_rate', 0.01, '')
tf.app.flags.DEFINE_float('decay_rate', 0.8, '')
tf.app.flags.DEFINE_integer('decay_steps', 1000, 'number of steps before decay learning rate')
# Forwarded to the TextRNN constructor by main().
tf.app.flags.DEFINE_bool('is_training', True, '')

# Number of passes over the training data performed by main().
tf.app.flags.DEFINE_integer('num_epoch', 10, 'number of epoch')

# ckpt_dir is where main() looks for and writes checkpoints.
# NOTE(review): the default is spelled "testrnn..." while cache_path below uses
# "textrnn..." - presumably a typo; confirm before renaming, since existing
# checkpoints may already live under the current name.
tf.app.flags.DEFINE_string("ckpt_dir","testrnn_multilabel_checkpoint/","checkpoint location for the model")
# cache_path is not referenced by the code visible in this notebook section.
tf.app.flags.DEFINE_string("cache_path","textrnn_multilabel_checkpoint/data_cache.pik","data chche for the model")

In [7]:
import time
def log(str):
    """Print ``str`` prefixed with the local time as ``[YYYY/MM/DD HH:MM:SS]``."""
    # The first six struct_time fields are year, month, day, hour, minute, second.
    stamp = "[%4d/%02d/%02d %02d:%02d:%02d]" % time.localtime()[:6]
    print(stamp, str)

In [11]:
def main(_):
    """Train TextRNN on the cached dataset, validating and checkpointing periodically.

    Steps:
      1. Load the HDF5 / pickle caches via ``load_data``.
      2. Build the model, then either restore the latest checkpoint found in
         ``FLAGS.ckpt_dir`` or initialise all variables.
      3. Train until ``FLAGS.num_epoch`` epochs have been completed in total
         (a restored run continues from the saved ``epoch_step``). The running
         mean training loss is logged every 200 batches; every 3000 batches
         the model is evaluated with ``do_eval`` and a checkpoint is written.

    Args:
        _: unused; present because ``tf.app.run`` calls ``main`` with argv.
    """
    # 1. Load data
    # NOTE(review): machine-specific absolute path; consider making it a flag.
    base_path = '/data/chenhy/data/ieee_zhihu_cup/'
    cache_file_h5py = base_path + 'data.h5'
    cache_file_pickle = base_path + 'vocab_label.pik'
    word2index,label2index,train_X,train_y,vaild_X,valid_y,test_X,test_y = load_data(cache_file_h5py, cache_file_pickle)
    vocab_size = len(word2index)

    print("train_X[0:5]:", train_X[0:5])
    print("train_Y[0:5]:", train_y[0:5])
    train_y_short = get_target_label_short(train_y[0])
    print("train_y_short:", train_y_short)
    
    # 2. Create the session and the model
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = TextRNN(FLAGS.batch_size, FLAGS.num_classes, vocab_size, FLAGS.sentence_len,FLAGS.embed_size, 
                        FLAGS.learning_rate, FLAGS.decay_steps, FLAGS.decay_rate, FLAGS.is_training)
        saver = tf.train.Saver()
        batch_size = FLAGS.batch_size
        if os.path.exists(os.path.join(FLAGS.ckpt_dir, 'checkpoint')):
            log("restore from checkpoint")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            log('init variables')
            sess.run(tf.global_variables_initializer())

        # Saver.save needs the target directory to exist already.
        if FLAGS.ckpt_dir:
            os.makedirs(FLAGS.ckpt_dir, exist_ok=True)
        save_path = os.path.join(FLAGS.ckpt_dir, 'model.ckpt')

        # 3. Train. This runs for both the fresh and the restored case; a
        # restored run resumes at the epoch recorded in the checkpoint.
        num_of_data = len(train_y)
        first_epoch = int(sess.run(model.epoch_step))
        for epoch in range(first_epoch, FLAGS.num_epoch):
            loss, counter = 0., 0
            # Step through every batch, including the last (possibly shorter) one.
            for start in range(0, num_of_data, batch_size):
                end = min(start + batch_size, num_of_data)
                loss_tmp, _ = sess.run([model.loss_val, model.train_op],
                                       feed_dict={model.input_x: train_X[start:end, :FLAGS.sentence_len],
                                                  model.input_y: train_y[start:end],
                                                  model.dropout_keep_prob: 1})
                loss, counter = loss + loss_tmp, counter + 1
                if counter % 200 == 0:
                    log("Epoch %d/ Batch %d/ Train Loss:%.3f"%(epoch, counter, loss/float(counter)))
                if counter % 3000 == 0:
                    print('run model on validation data...')
                    loss_valid, f1_score, precision, recall = do_eval(sess, model, vaild_X, valid_y)
                    log("Epoch %d/ Validation Loss:%.3f/ F1_score:%.3f/ Precision:%.3f/ Recall:%.3f"%(epoch, loss_valid, f1_score, precision, recall))
                    # Save a checkpoint tagged with the current epoch counter.
                    saver.save(sess, save_path, global_step=model.epoch_step)
            sess.run(model.epoch_increment)
            
    

In [12]:
#一些辅助函数
def get_target_label_short(eval_y):
    """Return the positions of the positive entries of a multi-hot label row.

    Args:
        eval_y: sequence of numbers, one per class.

    Returns:
        List of indices ``i`` with ``eval_y[i] > 0``, e.g. ``[45, 100, 1555]``.
    """
    positives = []
    for position, value in enumerate(eval_y):
        if value > 0:
            positives.append(position)
    return positives

def get_label_using_logits(logits, top_number=5):
    """Turn one row of raw logits into a list of predicted label indices.

    A label is predicted when its sigmoid probability is at least 0.5. The
    model is trained with sigmoid cross-entropy, so ``sigmoid(x) >= 0.5`` is
    the same as ``x >= 0``; the comparison is therefore done directly on the
    logit against 0 (comparing the raw logit against 0.5 would silently raise
    the probability threshold to about 0.62).

    Args:
        logits: 1-D sequence of raw (pre-sigmoid) scores, one per class.
        top_number: unused; kept so existing callers keep working.

    Returns:
        List of int class indices in ascending order. If no class reaches the
        threshold, the single highest-scoring class is returned. An empty
        input yields an empty list.
    """
    scores = np.asarray(logits)
    if scores.size == 0:
        return []
    predict_y = np.flatnonzero(scores >= 0).tolist()
    if not predict_y:
        # Always predict at least one label: fall back to the best-scoring class.
        predict_y = [int(np.argmax(scores))]
    return predict_y

def load_data(h5_file_path, pik_file_path):
    """Load the cached dataset splits and the vocabulary / label mappings.

    Args:
        h5_file_path: path of the HDF5 file holding the keys ``train_X``,
            ``train_Y``, ``vaild_X``, ``valid_Y``, ``test_X`` and ``test_Y``.
        pik_file_path: path of a pickle file holding the pair
            ``(word2index, label2index)``.

    Returns:
        Tuple ``(word2index, label2index, train_X, train_y, vaild_X, valid_y,
        test_X, test_y)``.

    Raises:
        RuntimeError: if either file does not exist.
    """
    if not (os.path.exists(h5_file_path) and os.path.exists(pik_file_path)):
        raise RuntimeError('No such file!!')

    print('cache files exist, going to load in...')
    print('loading h5_file...')
    # The file is opened read-only and is not closed here; the returned split
    # objects are obtained by indexing this open file.
    h5_file = h5py.File(h5_file_path, 'r')
    print('h5_file.keys:', h5_file.keys())
    # 'vaild_X' is the key exactly as it is spelled in the data file.
    split_keys = ('train_X', 'train_Y', 'vaild_X', 'valid_Y', 'test_X', 'test_Y')
    splits = [h5_file[key] for key in split_keys]

    print('loading pickle file')
    # NOTE: pickle.load can execute arbitrary code; only use trusted cache files.
    with open(pik_file_path, 'rb') as pkl:
        word2index, label2index = pickle.load(pkl)
    print('cache files load successful!')
    return (word2index, label2index, *splits)

def do_eval(sess, model, eval_X, eval_y, max_samples=3000):
    """Evaluate the model on the first ``max_samples`` examples, one at a time.

    Precision, recall and F1 are micro-averaged: true positives, false
    negatives and false positives are summed over all evaluated examples
    before the ratios are taken.

    Args:
        sess: session used to run the model.
        model: object exposing ``loss_val``, ``logits``, ``input_x``,
            ``input_y`` and ``dropout_keep_prob``.
        eval_X: sliceable collection of input rows.
        eval_y: sliceable collection of multi-hot label rows.
        max_samples: upper bound on the number of examples evaluated.

    Returns:
        Tuple ``(mean_loss, F1, precision, recall)``. All four are 0.0 when
        there is nothing to evaluate.
    """
    test_X, test_y = eval_X[:max_samples], eval_y[:max_samples]
    num_of_data = len(test_y)
    if num_of_data == 0:
        # Nothing to score; also avoids dividing the loss by zero below.
        return 0.0, 0.0, 0.0, 0.0

    loss = 0.
    # Small non-zero seeds keep the precision/recall divisions well-defined.
    label_dict_confuse = {'TP':0.000001, 'FN':0.000001, 'FP':0.000001}
    for start in range(num_of_data):
        end = start + 1
        lo,logits = sess.run([model.loss_val, model.logits], 
                        feed_dict={model.input_x: test_X[start:end,:FLAGS.sentence_len], model.input_y: test_y[start:end],
                                   model.dropout_keep_prob:1.0})
        loss += lo
        pre = get_label_using_logits(logits[0])
        label = get_target_label_short(test_y[start])
        if start == 0: print('label:',label, 'predict:', pre)
        # Set membership keeps the overlap count linear in the number of labels.
        label_set = set(label)
        inter = sum(1 for x in pre if x in label_set)
        label_dict_confuse['TP'] += inter
        label_dict_confuse['FN'] += len(label) - inter
        label_dict_confuse['FP'] += len(pre) - inter
    print(label_dict_confuse)
    p = float(label_dict_confuse['TP'])/(label_dict_confuse['TP']+label_dict_confuse['FP'])
    r = float(label_dict_confuse['TP'])/(label_dict_confuse['TP']+label_dict_confuse['FN'])
    if p + r == 0: return loss/num_of_data, 0, 0, 0
    F1 = (2 * p * r)/(p + r)
    return loss/num_of_data, F1, p, r


In [None]:
# Clear the default graph left over from earlier cells, then let tf.app.run
# parse the flags and invoke main().
tf.reset_default_graph()
tf.app.run()

cache files exist, going to load in...
loading h5_file...
h5_file.keys: KeysView(<HDF5 file "data.h5" (mode r)>)
loading pickle file
cache files load successful!
train_X[0:5]: [[ 832   60  256 1172 3407  516   96  138  103 1108   16    3   96  177
    22   11  672   53   18 1560 1560   15   65   12  180   10  342  173
    13  103  141  707  191   12  342  173   15   13   22   11  229  264
   163 1362  135 1249  156  156  731  115   84   10  808 1713  103  141
   229  264  788  421  103  141   12   95  316   10  808 1713  103  141
    12 2413 1227   15 1397  997   22  116  301  489   12   18  858   99
   596   98   26  646  813   10  386 1093  197  767   22   11 1179 1849
   593   84   22   11   94  102  322   60  190  220  583  355   10  153
   103  141  192  153   12   10 1244  466  116  103  141  189   58  130
    95  316   12  788  421  699   16    0    0    0    0    0    0    0
     0    0    0    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0    0 