In [1]:
# coding: utf-8

In [2]:
import os
import sys
import time
from datetime import timedelta

import numpy as np
import tensorflow as tf
from sklearn import metrics

from modelOneConv import *
from dataOwn import *
import pickle


In [3]:
class CnnKMaxPoolModel:
    """Checkpoint-restored CNN_K_MAXPOOL_DISEASE model with evaluation helpers.

    Constructing an instance builds the model from ``TCNN_K_Config``, opens a
    ``tf.Session`` and restores the weights stored under ``self.save_path``.
    The session stays open until ``close()`` is called; the instance can also
    be used as a context manager (``with CnnKMaxPoolModel() as model: ...``).
    """

    def __init__(self):
        """Build the model, open a session and restore the saved checkpoint.

        Raises:
            Exception: whatever the model construction or the checkpoint
                restore raises. If the failure happens after the session was
                opened, the session is closed before the error propagates.
        """
        self.config = TCNN_K_Config()
        self.vocabulary_word2index, self.vocabulary_index2word = create_voabulary(
            self.config.word2vec_model_path)
        # The vocabulary size is taken from the actual word2vec vocabulary and
        # stored in the config before the CNN model is built from that config.
        # NOTE(review): the +1 presumably reserves an index for padding/unknown
        # tokens - confirm against create_voabulary.
        self.config.vocab_size = len(self.vocabulary_word2index) + 1
        self.model = CNN_K_MAXPOOL_DISEASE(self.config)
        print(self.config.vocab_size)
        self.word2vecModel = Word2Vec.load(self.config.word2vec_model_path)

        # NOTE(review): the directory is called "textrnn" although this class
        # wraps a CNN model - confirm this is the intended checkpoint.
        self.save_dir = 'checkpoints/textrnn'
        # Path of the checkpoint holding the best validation result.
        self.save_path = os.path.join(self.save_dir, 'best_validation')

        self.session = tf.Session()
        try:
            self.session.run(tf.global_variables_initializer())
            self.saver = tf.train.Saver()
            # Load the saved model weights.
            self.saver.restore(sess=self.session, save_path=self.save_path)
        except Exception:
            # Do not leak the session (and the device memory it holds) when
            # the checkpoint is missing or incompatible.
            self.close()
            raise

    def close(self):
        """Close the TensorFlow session if it is still open (idempotent)."""
        session = getattr(self, 'session', None)
        if session is not None:
            session.close()
            self.session = None

    def __enter__(self):
        """Return the instance itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the session on leaving a ``with`` block; never swallows errors."""
        self.close()
        return False

    def get_time_dif(self, start_time):
        """Return the time elapsed since ``start_time`` as a ``timedelta``.

        Args:
            start_time: a timestamp previously obtained from ``time.time()``.

        Returns:
            ``datetime.timedelta`` rounded to whole seconds.
        """
        elapsed = time.time() - start_time
        return timedelta(seconds=int(round(elapsed)))

    def feed_data(self, x_batch, y_batch, keep_prob):
        """Build the feed dict mapping the model's placeholders to one batch.

        Args:
            x_batch: value fed to ``self.model.input_x``.
            y_batch: value fed to ``self.model.input_y``.
            keep_prob: value fed to ``self.model.dropout_keep_prob``.

        Returns:
            dict usable as ``feed_dict`` for ``Session.run``.
        """
        return {
            self.model.input_x: x_batch,
            self.model.input_y: y_batch,
            self.model.dropout_keep_prob: keep_prob,
        }

    def evaluate(self, sess, x, y):
        """Compute loss, accuracy and predictions of the model on a dataset.

        The data is processed in its original order, in batches of
        ``self.config.batch_size``, with the dropout keep probability fed as
        1.0. Per-batch loss and accuracy are weighted by the batch length so
        that a shorter final batch does not skew the averages.

        Args:
            sess: session-like object whose ``run(fetches, feed_dict=...)``
                returns ``(loss, acc, y_pred_cls)`` for one batch.
            x: sequence of input samples.
            y: sequence of labels, same length as ``x``.

        Returns:
            Tuple ``(mean_loss, mean_acc, eval_out, y_real)`` where
            ``eval_out`` holds the predicted classes (dtype as produced by the
            model) and ``y_real`` the labels, both aligned with the input order.

        Raises:
            ValueError: if the dataset is empty, ``x`` and ``y`` differ in
                length, or the configured batch size is not positive.
        """
        features = np.asarray(x)
        # Copy the labels so the returned array never aliases the caller's data.
        labels = np.array(y)
        data_len = len(features)
        if data_len == 0:
            raise ValueError("cannot evaluate an empty dataset")
        if len(labels) != data_len:
            raise ValueError(
                "x and y must have the same length, got %d and %d"
                % (data_len, len(labels)))

        batch_size = self.config.batch_size
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got %r" % (batch_size,))
        num_batch = (data_len + batch_size - 1) // batch_size
        print("data_len=", data_len, " num_batch=", num_batch)

        total_loss = 0.0
        total_acc = 0.0
        # Collect per-batch predictions and concatenate once at the end: this
        # keeps the predictions' own dtype (seeding the concatenation with an
        # empty list would promote class ids to float64) and avoids re-copying
        # the accumulated array on every batch.
        batch_predictions = []
        fetches = [self.model.loss, self.model.acc, self.model.y_pred_cls]
        for start_id in range(0, data_len, batch_size):
            end_id = start_id + batch_size
            x_batch = features[start_id:end_id]
            y_batch = labels[start_id:end_id]

            feed_dict = self.feed_data(x_batch, y_batch, 1.0)
            loss, acc, y_pred_cls = sess.run(fetches, feed_dict=feed_dict)

            batch_len = len(x_batch)
            total_loss += loss * batch_len
            total_acc += acc * batch_len
            batch_predictions.append(np.asarray(y_pred_cls))

        eval_out = np.concatenate(batch_predictions)
        return total_loss / data_len, total_acc / data_len, eval_out, labels

    def test(self, x_val, y_val):
        """Evaluate on the given data and print the resulting metrics.

        Prints the mean loss and accuracy, sklearn's classification report,
        the confusion matrix and the elapsed time.

        Args:
            x_val: input samples, already padded to the model's input length.
            y_val: labels matching ``x_val``.
        """
        start_time = time.time()
        print('Testing...')
        total_loss, total_acc, eval_out, y_real = self.evaluate(self.session, x_val, y_val)
        print("total_loss=", total_loss, " total_acc=", total_acc)
        print(metrics.classification_report(y_real, eval_out))

        cm = metrics.confusion_matrix(y_real, eval_out)
        print("cm====\n", cm)
        print("Time usage:", self.get_time_dif(start_time))

    def predictBatch(self, test_data_path, test_label_path):
        """Load a labelled test set from disk and report metrics on it.

        Args:
            test_data_path: path of the feature file passed to
                ``loadTrainOrTest_data_oneLabel``.
            test_label_path: path of the label file passed to
                ``loadTrainOrTest_data_oneLabel``.
        """
        x_val, y_val = loadTrainOrTest_data_oneLabel(
            test_data_path, test_label_path, self.vocabulary_word2index)
        # Pad every sample to the configured sentence length.
        x_val = pad_sequences(x_val, self.config.sentence_length)
        self.test(x_val, y_val)

In [4]:
# Build the model and restore its trained weights from the saved checkpoint
# (the constructor opens a TensorFlow session and calls saver.restore).
model = CnnKMaxPoolModel()

weight shape: [7, 100, 1, 6]
weight shape: [1200, 100]
weight shape: [100, 14]
input shape0= (?, 50)
sent_embed shape= Tensor("inference/embedding_lookup:0", shape=(?, 50, 100), dtype=float32)
input shape: (?, 50, 100, 1)
input_unstack shape: 100
conv1-con shape= (?, 50, 100, 6)
conv1-kemax-pool shape= (?, 50, 100, 6)
trained shape= (?, 1200)
out shape= (?, 14)
232015


NameError: name 'save_path' is not defined

In [None]:
# Paths of the test feature file and the matching label file.
test_data_path = "../cnnDatasets/testAdmin.feature"
test_label_path = "../cnnDatasets/testAdmin.label"
# Load and pad the test set, then print loss/accuracy, the classification
# report and the confusion matrix.
model.predictBatch(test_data_path, test_label_path)