# inference_tf.py
# coding: utf-8
"""
Create a .py file for TensorFlow inference.

Usage::

    python inference_tf.py -n NAME

writes ``./NAME_inference_tf.py``. The generated module imports the converted
model code from ``gen_code.NAME_tensorflow_<model>`` and loads weights from
``gen_pb_json_npy/NAME_<model>.npy`` and ``gen_model/NAME_tf_<model>.ckpt``.
"""
import argparse
from string import Template

# Source of the generated module. ``${prefix}`` is replaced by the
# user-supplied name. string.Template is used instead of str.format so the
# generated code's own ``{}`` format fields need no brace escaping.
_TEMPLATE = Template('''\
"""TensorFlow inference for models converted with the prefix ${prefix}."""
import importlib

import numpy as np
import PIL.Image as Image
import tensorflow as tf
from imagenet1000_clsid_to_human import cls_dict
from keras_param import models, shapes

# Model names that have a converted module under gen_code/.
SUPPORTED_MODELS = (
    "inception_v3", "vgg16", "vgg19", "resnet", "mobilenet", "xception")


def get_model(model):
    """Return (KitModel, npy_path, ckpt_path) for *model*.

    Raises ValueError for a model name not in SUPPORTED_MODELS.
    """
    if model not in SUPPORTED_MODELS:
        raise ValueError("unsupported model: %r" % (model,))
    module = importlib.import_module("gen_code.${prefix}_tensorflow_%s" % model)
    npy_path = "gen_pb_json_npy/${prefix}_%s.npy" % model
    ckpt_path = "gen_model/${prefix}_tf_%s.ckpt" % model
    return module.KitModel, npy_path, ckpt_path


def inference(img_path="cat1.jpeg", model="inception_v3"):
    """Classify *img_path* with *model* and return the top-5 predictions.

    The result has one line per class, highest score first, formatted as
    "<score * 100>% : <label>".
    """
    KitModel, npy_path, ckpt_path = get_model(model)
    try:
        inp, oup = KitModel(npy_path)
        with tf.Session() as sess:
            tf.train.Saver().restore(sess, ckpt_path)
            # convert("RGB") keeps grayscale/RGBA files at 3 channels.
            im = Image.open(img_path).convert("RGB").resize(shapes[model])
            img = np.expand_dims(np.array(im) / 255.0, 0)  # batch of one
            scores = sess.run(oup, feed_dict={inp: img}).squeeze()
    finally:
        # Reset only once the session is closed (resetting the default
        # graph under a live session is undefined), so repeated calls do
        # not pile up graphs.
        tf.reset_default_graph()
    top5 = np.argsort(scores)[-5:][::-1]
    result = "".join(
        "{:05.2f}% : {}\\n".format(scores[i] * 100, cls_dict[i]) for i in top5)
    print("Model: {}, ImgPath: {}".format(model, img_path))
    return result
''')


def build_inference_source(name):
    """Return the source text of the inference module for prefix *name*."""
    return _TEMPLATE.substitute(prefix=name)


def _parse_args(argv=None):
    """Parse the command line (``argv=None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description='input selfDefinedFileName')
    # Required: without it every generated path would contain "None".
    parser.add_argument('--selfDefinedFileName', '-n', type=str, required=True,
                        help='an self-defined name for an inference')
    return parser.parse_args(argv)


def main(argv=None):
    """Write ``./<name>_inference_tf.py`` for the name given on the CLI."""
    name = _parse_args(argv).selfDefinedFileName
    # Text mode: the source is a str, and "wb" raises TypeError on Python 3.
    with open("./{}_inference_tf.py".format(name), "w") as f:
        f.write(build_inference_source(name))


if __name__ == "__main__":
    main()