-
Notifications
You must be signed in to change notification settings - Fork 7
/
inference_wrapper.py
executable file
·135 lines (106 loc) · 4.75 KB
/
inference_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Model wrapper class for performing inference with a Text Detector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
import text_detector
class InferenceWrapper(object):
    """Model wrapper class for performing inference with a Text Detector.

    The inference graph can be built either from a model configuration
    object (build_graph_from_config) or from serialized GraphDef/SaverDef
    protos (build_graph_from_proto). Both return a restore function that
    loads the model variables from a checkpoint into a session.
    """

    def __init__(self):
        # Populated by build_model(); remains None when the graph is imported
        # via build_graph_from_proto(), which never calls build_model().
        self.model = None

    def build_model(self, model_config):
        """Builds the model for inference.

        Args:
            model_config: Object containing configuration for building the
                model.

        Returns:
            model: The text_detector.TextDetector instance (in "inference"
                mode) after build() has been called on it. It is also stored
                in self.model.
        """
        model = text_detector.TextDetector(model_config, mode="inference")
        model.build()
        self.model = model
        return model

    def inference_step(self, sess, encoded_image):
        """Runs one step of inference.

        Args:
            sess: TensorFlow Session object whose graph contains the tensors
                "image_input/image_feed:0" and "inference/text_bboxes:0".
            encoded_image: An encoded image string, fed to
                "image_input/image_feed:0".

        Returns:
            text_bboxes: The evaluated value of "inference/text_bboxes:0"
                (presumably a numpy array of text bounding boxes; its exact
                shape is defined by TextDetector -- TODO confirm).
        """
        # A one-element fetch list is used, so unwrap the single result.
        text_bboxes = sess.run(
            fetches=["inference/text_bboxes:0"],
            feed_dict={"image_input/image_feed:0": encoded_image})
        return text_bboxes[0]

    def _create_restore_fn(self, checkpoint_path, saver):
        """Creates a function that restores a model from checkpoint.

        Args:
            checkpoint_path: Checkpoint file or a directory containing a
                checkpoint file. If a directory is given, the latest
                checkpoint in it is used.
            saver: Saver for restoring variables from the checkpoint file.

        Returns:
            restore_fn: A function such that restore_fn(sess) loads model
                variables from the checkpoint file.

        Raises:
            ValueError: If checkpoint_path is a directory that contains no
                checkpoint file. A non-directory path is passed to
                saver.restore() unchecked.
        """
        if tf.gfile.IsDirectory(checkpoint_path):
            checkpoint_dir = checkpoint_path
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            if not checkpoint_path:
                # Report the directory that was searched; checkpoint_path
                # itself is now empty/None and would make a useless message.
                raise ValueError("No checkpoint file found in: %s"
                                 % checkpoint_dir)

        def _restore_fn(sess):
            tf.logging.info(
                "Loading model from checkpoint: %s", checkpoint_path)
            saver.restore(sess, checkpoint_path)
            tf.logging.info("Successfully loaded checkpoint: %s",
                            os.path.basename(checkpoint_path))

        return _restore_fn

    def build_graph_from_config(self, model_config, checkpoint_path):
        """Builds the inference graph from a configuration object.

        Args:
            model_config: Object containing configuration for building the
                model.
            checkpoint_path: Checkpoint file or a directory containing a
                checkpoint file.

        Returns:
            restore_fn: A function such that restore_fn(sess) loads model
                variables from the checkpoint file.

        Raises:
            ValueError: If checkpoint_path is a directory with no checkpoint.
        """
        tf.logging.info("Building model.")
        self.build_model(model_config)
        # The Saver must be created after the model so it captures the
        # variables build_model() just added to the default graph.
        saver = tf.train.Saver()
        return self._create_restore_fn(checkpoint_path, saver)

    def build_graph_from_proto(self, graph_def_file, saver_def_file,
                               checkpoint_path):
        """Builds the inference graph from serialized GraphDef and SaverDef protos.

        Args:
            graph_def_file: File containing a serialized GraphDef proto.
            saver_def_file: File containing a serialized SaverDef proto.
            checkpoint_path: Checkpoint file or a directory containing a
                checkpoint file.

        Returns:
            restore_fn: A function such that restore_fn(sess) loads model
                variables from the checkpoint file.

        Raises:
            ValueError: If checkpoint_path is a directory with no checkpoint.
        """
        # Load the Graph.
        tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(graph_def_file, "rb") as f:
            graph_def.ParseFromString(f.read())
        # name="" keeps the original tensor names (no import prefix), so the
        # names used in inference_step() resolve in the imported graph.
        tf.import_graph_def(graph_def, name="")

        # Load the Saver.
        tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
        saver_def = tf.train.SaverDef()
        with tf.gfile.GFile(saver_def_file, "rb") as f:
            saver_def.ParseFromString(f.read())
        saver = tf.train.Saver(saver_def=saver_def)

        return self._create_restore_fn(checkpoint_path, saver)