Skip to content

Commit

Permalink
0.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
ustcxmwu committed Mar 24, 2021
1 parent 71708b6 commit 9ce25a9
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 2 deletions.
2 changes: 1 addition & 1 deletion doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
author = 'Xiaomin Wu'

# The full version, including alpha/beta/rc tags
release = '0.1.0'
release = '0.1.1'


# -- General configuration ---------------------------------------------------
Expand Down
12 changes: 12 additions & 0 deletions nano/network/base_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABCMeta, abstractmethod

import torch.nn as nn


class BaseNetwork(nn.Module):

def __init__(self):
super().__init__()

def forward(self, x):
pass
83 changes: 82 additions & 1 deletion nano/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import time
from collections import namedtuple

import tensorflow as tf
from tensorflow.python.platform import gfile
import torch.nn as nn
import os
from typing import List

func_call_dict = dict()
FuncCallInfo = namedtuple('FuncCallInfo', ['call_cnt', 'cur_call_time', 'avg_call_time'])
Expand Down Expand Up @@ -43,7 +47,7 @@ class ProfileUtils(object):
def update_call_info(func_info, call_time):
func_info = func_info._replace(
cur_call_time=call_time,
avg_call_time=(func_info.avg_call_time*func_info.call_cnt+call_time) / (func_info.call_cnt + 1),
avg_call_time=(func_info.avg_call_time * func_info.call_cnt + call_time) / (func_info.call_cnt + 1),
call_cnt=func_info.call_cnt + 1
)
return func_info
Expand All @@ -59,6 +63,7 @@ def timeit(**info):
Returns: None
"""

def func_wrapper(method):
def time_func(*args, **kwargs):
ts = time.time()
Expand All @@ -76,7 +81,83 @@ def time_func(*args, **kwargs):
return result

return time_func

return func_wrapper


class ModelUtils(object):
"""
Utils for model convert and validation
"""

@staticmethod
def freeze_tb_model_to_pb(input_checkpoint, output_pb_file, output_none_names):
if output_none_names is None:
output_none_names = ["ArgMax"]
saver = tf.compat.v1.train.import_meta_graph(input_checkpoint + ".meta", clear_devices=True)
graph = tf.compat.v1.get_default_graph()
input_graph_def = graph.as_graph_def()

with tf.compat.v1.Session() as sess:
saver.restore(sess, input_checkpoint)
for tensor in sess.graph.get_operations():
print(tensor.name, tensor.values())
output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=input_graph_def,
output_none_names=output_none_names
)

with tf.qfile.GFile(output_pb_file, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("{} ops in the final graph.".format(len(output_graph_def.node)))
tf.compat.v1.reset_default_graph()

@staticmethod
def convert_pb_to_tflite(pb_file, lite_file, input_node_names: List[str] = None,
output_node_names: List[str] = None):
if input_node_names is None:
input_node_names = ["X0", "legal_actions"]
if output_node_names is None:
output_node_names = ["inference"]
converter = tf.contrib.lite.TocoConverter.from_frozen_file(pb_file, input_node_names, output_node_names)
tflite_model = converter.convert()
with open(lite_file, "wb") as f:
f.write(tflite_model)

@staticmethod
def get_tf_sess_from_pb(pb_file, write_tensor: bool = False):
sess = tf.compat.v1.Session()
with gfile.FastGFile(pb_file, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')

sess.run(tf.compat.v1.global_variables_initializer())

tensors = []
for tensor in tf.contrib.graph_editor.get_tensors(tf.compat.v1.get_default_graph()):
print(tensor)
if "read" in tensor.name or "reduction_indices" in tensor.name or "StopGradient" in tensor.name or \
"output" in tensor.name:
tensors.append(str(tensor) + "\n")

if write_tensor:
file = os.path.splitext(pb_file)[0] + ".tensor"
with open(file, mode='w') as f:
f.writelines(tensors)

return sess

@staticmethod
def inference_validate(input_name, legal_name, feed_dict, inference_name, expected_action, pb_file):
sess = ModelUtils.get_tf_sess_from_pb(pb_file)
input = sess.graph.get_tensor_by_name(input_name + ":0")
legal = sess.graph.get_tensor_by_name(legal_name + ":0")
inference = sess.graph.get_tensor_by_name(inference_name + ":0")
action = sess.run(inference, feed_dict={input: feed_dict[input_name], legal: feed_dict[legal_name]})
if action == expected_action:
print("test pass")

return action
1 change: 1 addition & 0 deletions nano/utils/logger_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,4 @@ def run(self):
publish.join()



0 comments on commit 9ce25a9

Please sign in to comment.