# 模型转化

In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras

# Report the interpreter and library versions this notebook was executed with
print(tf.__version__)
print(sys.version_info)
for lib in (mpl, np, pd, sklearn, tf, keras):
    print(f"{lib.__name__} {lib.__version__}")

2.0.0
sys.version_info(major=3, minor=6, micro=10, releaselevel='final', serial=0)
matplotlib 3.1.2
numpy 1.18.1
pandas 0.25.3
sklearn 0.22.1
tensorflow 2.0.0
tensorflow_core.keras 2.2.4-tf


In [2]:
# Enable on-demand GPU memory growth for every visible GPU.
# Per the TensorFlow docs this has to happen before any GPU is initialized,
# which is why it runs right after the imports.
gpus = tf.config.experimental.list_physical_devices('GPU')
for physical_gpu in gpus:
    tf.config.experimental.set_memory_growth(physical_gpu, True)

## 签名函数转SavedModel
### 签名函数和具体函数的转化

In [3]:
# Signature function: input_signature restricts `cube` to 1-D int32 tensors of any
# length. The spec's name 'x' (not the Python parameter name `z`) is what shows up
# as the input key of the exported signature.
@tf.function(
    input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32, name='x')]
)
def cube(z):
    """Return the element-wise cube of the int32 vector ``z``."""
    return tf.pow(z, 3)

In [4]:
# get_concrete_function returns the ConcreteFunction traced for a given input spec
int32_vector_spec = tf.TensorSpec([None], tf.int32)
cube_func_int32 = cube.get_concrete_function(int32_vector_spec)
print(cube_func_int32)

<tensorflow.python.eager.function.ConcreteFunction object at 0x000001E885796630>


In [5]:
# Check whether other compatible inputs resolve to the very same ConcreteFunction
# object. This is an identity test (`is`), not just a comparison of signatures.
fixed_len_spec = tf.TensorSpec([5], tf.int32)
int_tensor = tf.constant([1, 2, 3])
for probe in (fixed_len_spec, int_tensor):
    print(cube_func_int32 is cube.get_concrete_function(probe))

True
True


In [6]:
# A ConcreteFunction exposes the FuncGraph it was traced into
traced_graph = cube_func_int32.graph
traced_graph

<tensorflow.python.framework.func_graph.FuncGraph at 0x1e885796128>

In [7]:
# Call the tf.function directly (int32 is also the default dtype for Python ints)
cube_result = cube(tf.constant([1, 2, 3], dtype=tf.int32))
print(cube_result)

tf.Tensor([ 1  8 27], shape=(3,), dtype=int32)


In [8]:
# Call the ConcreteFunction directly; it yields the same values as the tf.function
concrete_result = cube_func_int32(tf.constant([1, 2, 3], dtype=tf.int32))
print(concrete_result)

tf.Tensor([ 1  8 27], shape=(3,), dtype=int32)


### 签名函数转SavedModel

In [9]:
# Export directory for the SavedModel, relative to the current working directory.
# (os.path.join with a single argument was a no-op, so a plain string is used.)
logdir = 'signature_to_savedmodel'
# makedirs(exist_ok=True) replaces the racy exists()-then-mkdir() pattern
os.makedirs(logdir, exist_ok=True)

# Conversion: attach the tf.function to a tf.Module as an attribute, then save it.
# No explicit `signatures` argument is passed, so TensorFlow picks this single
# input_signature-bearing tf.function as the 'serving_default' signature.
to_export = tf.Module()
to_export.cube = cube
tf.saved_model.save(to_export, logdir)

INFO:tensorflow:Assets written to: signature_to_savedmodel\assets


In [10]:
# 工具查看
!saved_model_cli show --dir ./signature_to_savedmodel --all


MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['x'] tensor_info:
        dtype: DT_INT32
        shape: (-1)
        name: serving_default_x:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_INT32
        shape: (-1)
        name: PartitionedCall:0
  Method name is: tensorflow/serving/predict


2020-02-10 16:53:58.442151: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll


In [11]:
# Load the SavedModel back from the same export directory and call the restored
# function, to check the round trip (reuses `logdir` rather than a duplicated path)
imported = tf.saved_model.load(logdir)
imported.cube(tf.constant([2]))

<tf.Tensor: id=116, shape=(1,), dtype=int32, numpy=array([8])>

## Keras模型转ConcreteFunction

In [12]:
# Load the Keras model from its HDF5 file and run it once on an all-ones batch as a smoke test
keras_model_path = './graph_def_and_weights/fashion_mnist_model.h5'
loaded_keras_model = keras.models.load_model(keras_model_path)
ones_input = np.ones((1, 28, 28))
loaded_keras_model(ones_input)

<tf.Tensor: id=608, shape=(1, 10), dtype=float32, numpy=
array([[0.11120097, 0.03724404, 0.06243855, 0.01093319, 0.03455226,
        0.00780652, 0.1908657 , 0.02144313, 0.49145293, 0.03206269]],
      dtype=float32)>

In [13]:
# Wrap the Keras model call in a tf.function so it can be traced into a graph
@tf.function
def run_model(x):
    return loaded_keras_model(x)


# Trace a ConcreteFunction whose spec mirrors the model's first input tensor
model_input = loaded_keras_model.inputs[0]
model_input_spec = tf.TensorSpec(model_input.shape, model_input.dtype)
keras_concrete_func = run_model.get_concrete_function(model_input_spec)

In [14]:
# Call the ConcreteFunction directly with an explicitly float32 batch.
# Assumes the model's input dtype (used for the traced spec) is float32; confirm if the model changes.
float_batch = np.ones((1, 28, 28), dtype=np.float32)
keras_concrete_func(tf.constant(float_batch))

<tf.Tensor: id=640, shape=(1, 10), dtype=float32, numpy=
array([[0.11120097, 0.03724404, 0.06243855, 0.01093319, 0.03455226,
        0.00780652, 0.1908657 , 0.02144313, 0.49145293, 0.03206269]],
      dtype=float32)>