In [1]:
from model_v2 import build_lm_classifier_tagger_inference, LSTM_SAVED_STATE
from utils import get_batch_classifier_inference, clean_text_v4 as clean_text
import json
import tensorflow as tf
import numpy as np

In [2]:
def load_graph(frozen_graph_filename, name="prefix"):
    """Load a frozen TensorFlow GraphDef protobuf into a brand-new ``tf.Graph``.

    Args:
        frozen_graph_filename: path to the binary-serialized GraphDef file.
        name: scope prepended to every imported op/tensor name. Defaults to
            "prefix", which existing lookups such as
            'prefix/LanguageModel/fw_inputs:0' rely on; pass a different value
            only if the tensor names used afterwards are adjusted to match.

    Returns:
        A new ``tf.Graph`` containing the imported graph definition.
    """
    # Read the protobuf bytes from disk and parse them into a GraphDef message.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a fresh Graph (not the global default graph). Every node name
    # gets the `name` scope prepended, e.g. "<name>/Classifier/Softmax".
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=name)
    return graph

In [3]:
# Path to the frozen GraphDef protobuf for this model, parsed by load_graph.
FROZEN_GRAPH_PATH = "109_Bank/checkpoints/chatbot_frozen/frozen.pb-723"
graph = load_graph(FROZEN_GRAPH_PATH)

In [5]:
# Resolve the tensors used at inference time by their fully qualified names.
# Every name carries the "prefix/" scope applied when the graph was imported.
# NOTE(review): tag_score is looked up here but not fetched by `inference`.
(inputs, seq_lens, char_lens, bptt,
 predict_prob, predict_tags, tag_score) = (
    graph.get_tensor_by_name(tensor_name)
    for tensor_name in (
        'prefix/LanguageModel/fw_inputs:0',
        'prefix/LanguageModel/seq_lens:0',
        'prefix/LanguageModel/fw_char_lens:0',
        'prefix/LanguageModel/bptt:0',
        'prefix/Classifier/Softmax:0',
        'prefix/SequenceTagger/cond/Merge:0',
        'prefix/SequenceTagger/cond/Merge_1:0',
    )
)

In [6]:
# allow_growth: allocate GPU memory on demand rather than pre-allocating it.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# The session is bound to the imported frozen graph, not the default graph.
session = tf.Session(graph=graph, config=config)

In [7]:
def _load_json(path):
    """Return the parsed contents of the JSON file at ``path``.

    The file is opened as UTF-8 explicitly instead of relying on the platform's
    locale default encoding. Non-ASCII vocabulary entries (e.g. the Vietnamese
    text this model is queried with) then decode the same way on every OS.
    """
    with open(path, 'r', encoding='utf-8') as inp:
        return json.load(inp)


# Word and character vocabularies used to encode input text.
word2idx = _load_json('109_Bank/word2idx.json')
char2idx = _load_json('109_Bank/char2idx.json')
# Character-id sequence for every vocabulary word. Raises KeyError if a word
# contains a character that is missing from char2idx.
word2char = {w: [char2idx[c] for c in w] for w in word2idx}
# Class / tag label vocabularies plus their inverse (id -> label) maps.
# NOTE(review): these are read from 'Bank/' while the word/char vocabularies
# come from '109_Bank/' -- confirm both directories belong to the same model.
class2idx = _load_json('Bank/class2idx.json')
idx2class = {i: w for w, i in class2idx.items()}
tag2idx = _load_json('Bank/tag2idx.json')
idx2tag = {i: w for w, i in tag2idx.items()}

In [10]:
def inference(texts, bsz=32, bptt_len=20):
    """Run the frozen classifier + sequence tagger on raw input sentences.

    Args:
        texts: iterable of raw input strings.
        bsz: batch size forwarded to ``get_batch_classifier_inference``.
        bptt_len: value fed to the graph's ``bptt`` tensor. The default of 20
            keeps the previously hard-coded behaviour.

    Returns:
        A list with one entry per batch. Each entry is the ``session.run``
        result ``[class_probs, tag_ids]``, fetched from the Classifier softmax
        and SequenceTagger outputs. Ids are not decoded to labels here.

    Raises:
        KeyError: if ``word2char`` has no ``'<UNK>'`` entry and at least one
            word is encoded.
    """
    # NOTE(review): clean_text is assumed to return a token list (each element
    # is looked up as a word below) -- confirm against utils.clean_text_v4.
    tokenized = [clean_text(x.strip()) for x in texts]
    # Encode each token as its list of character ids; unknown words map to <UNK>.
    char_ids = [[word2char.get(w, word2char['<UNK>']) for w in sent]
                for sent in tokenized]
    # Sentences and words vary in length, so this nested list is ragged.
    # NumPy >= 1.24 raises ValueError when inferring an array from ragged data.
    # dtype=object builds the same object array older NumPy produced (with a
    # deprecation warning) for such input.
    encoded = np.array(char_ids, dtype=object)

    results = []
    for chars, lens, cl in get_batch_classifier_inference(encoded, bsz):
        res = session.run([predict_prob, predict_tags], feed_dict={
            inputs: chars,
            seq_lens: lens,
            char_lens: cl,
            bptt: bptt_len,
        })
        results.append(res)
    return results

In [11]:
# Smoke test: run one sample Vietnamese query through the classifier and tagger.
inference(texts=['Dịch vụ 4d secure là gì vậy'])

[[array([[2.6383190e-10, 4.1221051e-07, 2.5529372e-09, 4.4460098e-06,
          1.3986401e-10, 4.0912629e-09, 9.9999332e-01, 1.7166562e-06,
          1.2660516e-07, 1.8831863e-09, 1.0862812e-09, 4.7357931e-09,
          2.1025111e-08, 3.9238179e-10, 3.7833234e-09, 3.3429002e-09,
          1.2676178e-08, 1.2719380e-09, 4.1492676e-09]], dtype=float32),
  array([[0, 0, 0, 2, 2, 0, 0, 0, 0]], dtype=int32)]]