/
make_checkpoint.py
66 lines (55 loc) · 2.41 KB
/
make_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
## make_checkpoint.py -- convert a .pb to a TF session dump
##
## Copyright (C) 2017, Nicholas Carlini <nicholas@carlini.com>.
##
## This program is licenced under the BSD 2-Clause licence,
## contained in the LICENCE file in this directory.
## Imports and setup: bring in TensorFlow, make the local DeepSpeech
## checkout importable, and parse the frozen inference graph from disk.
import sys

import numpy as np
import tensorflow as tf
# Import only what we use instead of a wildcard import.
from tensorflow.core.framework.graph_pb2 import GraphDef

# Make the DeepSpeech checkout importable as a package; must happen
# before the util.* imports below, which live inside that checkout.
sys.path.append("DeepSpeech")
from util.audio import audiofile_to_input_vector
from util.text import ctc_label_dense_to_sparse

# Okay, so this is ugly. We don't want DeepSpeech to crash
# when we haven't built the language model.
# So we're just going to monkeypatch TF and make it a no-op.
# Sue me.  (The patch must be in place BEFORE DeepSpeech is imported.)
tf.load_op_library = lambda x: x
import DeepSpeech as DeepSpeech

# Parse the frozen inference graph.  Use a context manager so the file
# handle is closed promptly instead of leaked to the GC.
graph_def = GraphDef()
with open("models/output_graph.pb", "rb") as pb_file:
    graph_def.ParseFromString(pb_file.read())
# Build a fresh graph whose input is a placeholder we control, import the
# frozen .pb into it, then copy its weights into a variable-backed copy of
# the DeepSpeech network so they can be saved as a normal TF checkpoint.
with tf.Graph().as_default() as graph:
    # Rank-3 float input for the MFCC features.
    # NOTE(review): axis meaning assumed to be (batch, time, features) from
    # how BiRNN consumes it below — confirm against DeepSpeech.BiRNN.
    new_input = tf.placeholder(tf.float32, [None, None, None],
                               name="new_input")
    # Load the saved .pb into the current graph to let us grab
    # access to the weights.  All imported ops get the "newname/" prefix.
    logits, = tf.import_graph_def(
        graph_def,
        input_map={"input_node:0": new_input},
        return_elements=['logits:0'],
        name="newname",
        op_dict=None,
        producer_op_list=None
    )
    # Now let's dump these weights into a new copy of the network.
    with tf.Session(graph=graph) as sess:
        # Sample sentence, to make sure we've done it right.
        mfcc = audiofile_to_input_vector("sample_input.wav", 26, 9)
        # Okay, so this is ugly again.
        # We just want DeepSpeech's global initialization to not crash.
        tf.app.flags.FLAGS.alphabet_config_path = "DeepSpeech/data/alphabet.txt"
        DeepSpeech.initialize_globals()
        # Build a second, variable-backed copy of the acoustic model.
        # NOTE(review): the third argument ([0]*10) looks like dummy dropout
        # rates, presumably ignored at inference time — confirm in BiRNN.
        logits2 = DeepSpeech.BiRNN(new_input, [len(mfcc)], [0]*10)
        # Here's where all the work happens. Copy the variables
        # over from the .pb to the session object: evaluate each constant
        # under its "newname/" alias and assign it to the fresh variable
        # of the same name.
        for var in tf.global_variables():
            sess.run(var.assign(sess.run('newname/'+var.name)))
        # Test to make sure we did it right: run the same audio through
        # both copies and check that the logits (nearly) agree.
        res = (sess.run(logits, {new_input: [mfcc],
            'newname/input_lengths:0': [len(mfcc)]}).flatten())
        res2 = (sess.run(logits2, {new_input: [mfcc]})).flatten()
        print('This value should be small',np.sum(np.abs(res-res2)))
        # And finally save the constructed session.
        saver = tf.train.Saver()
        saver.save(sess, "models/session_dump")