In [1]:
%reload_ext watermark
%reload_ext autoreload
%autoreload 2
%watermark -p numpy,sklearn,pandas
%watermark -p cv2,PIL,matplotlib
%watermark -p torch,torchvision,torchaudio
%watermark -p tensorflow,tensorboard,tflite
%watermark -p onnx,onnxruntime,tensorrt,tvm
%matplotlib inline
%config InlineBackend.figure_format='retina'
%config IPCompleter.use_jedi = False

from IPython.display import display, Markdown, HTML, Image, Javascript
from IPython.core.magic import register_line_cell_magic, register_line_magic, register_cell_magic
display(HTML('<style>.container { width:%d%% !important; }</style>' % 95))

import sys, os, io, time, random, math
import json, base64, requests, shutil
import os.path as osp
import numpy as np

def _IMPORT(x):
    try:
        x = x.strip()
        if x.startswith('https://'):
            x = x[8:]
        if not x.endswith('.py'):
            x = x + '.py'
        if x[0] == '/':
            with open(x) as fr:
                x = fr.read()
        else:
            x = x.replace('blob/main/', '').replace('blob/master/', '')
            if x.startswith('raw.githubusercontent.com'):
                uri = 'https://' + x
                x = requests.get(uri)
                if x.status_code == 200:
                    x = x.text
            elif x.startswith('github.com'):
                uri = x.replace('github.com', 'raw.githubusercontent.com')
                mod = uri.split('/')
                for s in ['main', 'master']:
                    uri = 'https://' + '/'.join(mod[:3]) + s + '/'.join(mod[-3:])
                    x = requests.get(uri)
                    if x.status_code == 200:
                        x = x.text
                        break
            elif x.startswith('gitee.com'):
                mod = x.split('/')
                for s in ['/raw/main/', '/raw/master/']:
                    uri = 'https://' + '/'.join(mod[:3]) + s + '/'.join(mod[3:])
                    x = requests.get(uri)
                    if x.status_code == 200:
                        x = x.text
                        break
        exec(x, globals())
    except:
        pass

def _DIR(x, dumps=True, ret=True):
    attrs = sorted([y for y in dir(x) if not y.startswith('_')])
    result = '%s: %s' % (str(type(x))[8:-2], json.dumps(attrs) if dumps else attrs)
    if ret:
        return result
    print(result)

numpy 1.19.5
sklearn 0.0
pandas 1.1.5
cv2 4.5.3
PIL 8.3.1
matplotlib 3.3.4
torch 1.8.1+cu101
torchvision 0.9.1+cu101
torchaudio not installed
tensorflow 2.6.0
tensorboard 2.6.0
tflite 2.4.0
onnx 1.10.1
onnxruntime 1.8.1
tensorrt not installed
tvm not installed


In [2]:
import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

In [3]:
MODEL_PATH = '/data/nb_data/saved_models'

## Load Model Graph

### Using Saved Models

In [4]:
model_loaded = tf.saved_model.load(MODEL_PATH, tags=['serve']) # TODO attr not found value
_DIR(model_loaded)

'tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject: ["base_model", "call_and_return_all_conditional_losses", "conv_3x3_layer", "dropout_layer", "fc_layers", "graph_debug_info", "input_projection", "input_projection2", "keras_api", "pos_encoding", "pos_encoding2", "regularization_losses", "signatures", "temporal_bn_layers", "temporal_conv_layers", "tensorflow_git_version", "tensorflow_version", "trainable_variables", "transformer_layers", "transformer_layers2", "variables", "within_period_fc_layers"]'

#### ConcreteFunction

In [5]:
signatures = model_loaded.signatures
_DIR(signatures), '-'*90, signatures.keys()

('tensorflow.python.saved_model.signature_serialization._SignatureMap: ["get", "items", "keys", "values"]',
 '------------------------------------------------------------------------------------------',
 KeysView(_SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(*, input_1) at 0x7FEF4DC3E828>})))

In [6]:
frozen_func = signatures['serving_default']
_DIR(frozen_func), '-'*90, frozen_func.name

('tensorflow.python.saved_model.load._WrapperFunction: ["add_gradient_functions_to_graph", "add_to_graph", "captured_inputs", "function_def", "graph", "inputs", "name", "output_dtypes", "output_shapes", "outputs", "pretty_printed_signature", "structured_input_signature", "structured_outputs", "trainable_variables", "variables"]',
 '------------------------------------------------------------------------------------------',
 b'__inference_signature_wrapper_2710')

In [7]:
len(frozen_func.inputs), len(frozen_func.outputs)

(227, 2)

In [8]:
frozen_func = convert_variables_to_constants_v2(frozen_func, lower_control_flow=False, aggressive_inlining=True)

In [9]:
len(frozen_func.inputs), len(frozen_func.outputs)

(1, 2)

In [10]:
frozen_func.inputs, '-'*60, frozen_func.outputs

([<tf.Tensor 'input_1:0' shape=(None, 64, 112, 112, 3) dtype=float32>],
 '------------------------------------------------------------',
 [<tf.Tensor 'Identity:0' shape=(None, None, 32) dtype=float32>,
  <tf.Tensor 'Identity_1:0' shape=(None, None, 1) dtype=float32>])

#### FunctionDef & OpDef 

In [11]:
_DIR(frozen_func.function_def), '-'*90, _DIR(frozen_func.function_def.signature), '-'*90, _DIR(frozen_func.function_def.node_def)

('tensorflow.core.framework.function_pb2.FunctionDef: ["ArgAttrEntry", "ArgAttrs", "AttrEntry", "ByteSize", "Clear", "ClearExtension", "ClearField", "ControlRetEntry", "CopyFrom", "DESCRIPTOR", "DiscardUnknownFields", "Extensions", "FindInitializationErrors", "FromString", "HasExtension", "HasField", "IsInitialized", "ListFields", "MergeFrom", "MergeFromString", "ParseFromString", "RegisterExtension", "ResourceArgUniqueIdEntry", "RetEntry", "SerializePartialToString", "SerializeToString", "SetInParent", "UnknownFields", "WhichOneof", "arg_attr", "attr", "control_ret", "node_def", "resource_arg_unique_id", "ret", "signature"]',
 '------------------------------------------------------------------------------------------',
 'tensorflow.core.framework.op_def_pb2.OpDef: ["ArgDef", "AttrDef", "ByteSize", "Clear", "ClearExtension", "ClearField", "CopyFrom", "DESCRIPTOR", "DiscardUnknownFields", "Extensions", "FindInitializationErrors", "FromString", "HasExtension", "HasField", "IsInitialized"

In [12]:
frozen_func.function_def.attr, frozen_func.function_def.signature.name

({}, '__inference_pruned_35300')

#### FuncGraph

In [13]:
_DIR(frozen_func.graph)

'tensorflow.python.framework.func_graph.FuncGraph: ["add_capture", "add_to_collection", "add_to_collections", "as_default", "as_graph_def", "as_graph_element", "building_function", "capture", "capture_by_value", "capture_call_time_value", "capture_distributed_variable", "capture_eager_tensor", "captured", "captures", "clear_captures", "clear_collection", "collections", "colocate_with", "container", "control_captures", "control_dependencies", "control_outputs", "create_op", "deferred_external_captures", "deferred_internal_captures", "device", "external_captures", "finalize", "finalized", "get_all_collection_keys", "get_collection", "get_collection_ref", "get_name_scope", "get_operation_by_name", "get_operations", "get_tensor_by_name", "gradient_override_map", "graph_def_versions", "inputs", "internal_captures", "is_control_flow_graph", "is_feedable", "is_fetchable", "mark_as_unsaveable", "name", "name_scope", "outer_graph", "output_shapes", "output_types", "outputs", "pop_capture", "pre

In [14]:
frozen_func.graph.inputs, '-'*60, frozen_func.graph.outputs

([<tf.Tensor 'input_1:0' shape=(None, 64, 112, 112, 3) dtype=float32>],
 '------------------------------------------------------------',
 [<tf.Tensor 'Identity:0' shape=(None, None, 32) dtype=float32>,
  <tf.Tensor 'Identity_1:0' shape=(None, None, 1) dtype=float32>])

In [15]:
operations = frozen_func.graph.get_operations()
_DIR(operations), '-'*60, _DIR(operations[0])

('list: ["append", "clear", "copy", "count", "extend", "index", "insert", "pop", "remove", "reverse", "sort"]',
 '------------------------------------------------------------',
 'tensorflow.python.framework.ops.Operation: ["colocation_groups", "control_inputs", "device", "get_attr", "graph", "inputs", "name", "node_def", "op_def", "outputs", "run", "traceback", "type", "values"]')

In [16]:
for i, op in enumerate(operations[:2] + operations[-2:]):
    if i == 2:
        print('*'*80)
    print('inputs:', op.inputs, 'outputs:', op.outputs, 'type:', op.type, 'values:', op.values)
    print('-'*80)
    print(op)
    print('+'*80)

inputs: () outputs: [<tf.Tensor 'input_1:0' shape=(None, 64, 112, 112, 3) dtype=float32>] type: Placeholder values: <bound method Operation.values of <tf.Operation 'input_1' type=Placeholder>>
--------------------------------------------------------------------------------
name: "input_1"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: -1
      }
      dim {
        size: 64
      }
      dim {
        size: 112
      }
      dim {
        size: 112
      }
      dim {
        size: 3
      }
    }
  }
}

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
inputs: () outputs: [<tf.Tensor 'StatefulPartitionedCall/resnet_period_estimator_29/mul_1/x:0' shape=() dtype=float32>] type: Const values: <bound method Operation.values of <tf.Operation 'StatefulPartitionedCall/resnet_period_estimator_29/mul_1/x' type=Const>>
-------------------------------------------------

### Using Frozen pb

In [17]:
tf.compat.v1.reset_default_graph()
with tf.Graph().as_default() as frozen_pb_graph:
    graph_def = tf.compat.v1.GraphDef()
    with open(f'{MODEL_PATH}/frozen_graph.pb', "rb") as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def)

In [18]:
from google.protobuf import text_format

tf.compat.v1.reset_default_graph()
with tf.Graph().as_default() as frozen_pbtxt_graph:
    graph_def = tf.compat.v1.GraphDef()
    with open(f'{MODEL_PATH}/frozen_graph.pbtxt', "rb") as f:
        text_format.Parse(f.read(), graph_def)
        tf.import_graph_def(graph_def)

### Using tf2onnx

In [19]:
from tf2onnx.tf_loader import from_saved_model

In [20]:
graph_def, inputs, outputs, initialized_tables, tensors_to_rename = from_saved_model(
    MODEL_PATH, None, None, tag='serve', return_initialized_tables=True, return_tensors_to_rename=True)

'--signature_def' not specified, using first signature: serving_default


Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`


In [21]:
_DIR(graph_def), '-'*60, inputs, '-'*60, outputs

('tensorflow.core.framework.graph_pb2.GraphDef: ["ByteSize", "Clear", "ClearExtension", "ClearField", "CopyFrom", "DESCRIPTOR", "DiscardUnknownFields", "Extensions", "FindInitializationErrors", "FromString", "HasExtension", "HasField", "IsInitialized", "ListFields", "MergeFrom", "MergeFromString", "ParseFromString", "RegisterExtension", "SerializePartialToString", "SerializeToString", "SetInParent", "UnknownFields", "WhichOneof", "library", "node", "version", "versions"]',
 '------------------------------------------------------------',
 ['input_1:0'],
 '------------------------------------------------------------',
 ['Identity:0', 'Identity_1:0'])

In [22]:
tf.compat.v1.reset_default_graph()
with tf.Graph().as_default() as tf2onnx_graph:
    tf.import_graph_def(graph_def)

## Graph  & Node

In [23]:
_DIR(frozen_func.graph), '-'*90, _DIR(frozen_pb_graph), '-'*90, _DIR(frozen_pbtxt_graph), '-'*90, _DIR(tf2onnx_graph)

('tensorflow.python.framework.func_graph.FuncGraph: ["add_capture", "add_to_collection", "add_to_collections", "as_default", "as_graph_def", "as_graph_element", "building_function", "capture", "capture_by_value", "capture_call_time_value", "capture_distributed_variable", "capture_eager_tensor", "captured", "captures", "clear_captures", "clear_collection", "collections", "colocate_with", "container", "control_captures", "control_dependencies", "control_outputs", "create_op", "deferred_external_captures", "deferred_internal_captures", "device", "external_captures", "finalize", "finalized", "get_all_collection_keys", "get_collection", "get_collection_ref", "get_name_scope", "get_operation_by_name", "get_operations", "get_tensor_by_name", "gradient_override_map", "graph_def_versions", "inputs", "internal_captures", "is_control_flow_graph", "is_feedable", "is_fetchable", "mark_as_unsaveable", "name", "name_scope", "outer_graph", "output_shapes", "output_types", "outputs", "pop_capture", "pr

### GraphDef

In [24]:
graph_def = frozen_func.graph.as_graph_def()
# graph_def = frozen_pb_graph.as_graph_def()
# graph_def = tf2onnx_graph.as_graph_def()
_DIR(graph_def)

'tensorflow.core.framework.graph_pb2.GraphDef: ["ByteSize", "Clear", "ClearExtension", "ClearField", "CopyFrom", "DESCRIPTOR", "DiscardUnknownFields", "Extensions", "FindInitializationErrors", "FromString", "HasExtension", "HasField", "IsInitialized", "ListFields", "MergeFrom", "MergeFromString", "ParseFromString", "RegisterExtension", "SerializePartialToString", "SerializeToString", "SetInParent", "UnknownFields", "WhichOneof", "library", "node", "version", "versions"]'

### NodeDef

In [25]:
source_graph_def = graph_def
_DIR(source_graph_def), '-'*90, source_graph_def.versions

('tensorflow.core.framework.graph_pb2.GraphDef: ["ByteSize", "Clear", "ClearExtension", "ClearField", "CopyFrom", "DESCRIPTOR", "DiscardUnknownFields", "Extensions", "FindInitializationErrors", "FromString", "HasExtension", "HasField", "IsInitialized", "ListFields", "MergeFrom", "MergeFromString", "ParseFromString", "RegisterExtension", "SerializePartialToString", "SerializeToString", "SetInParent", "UnknownFields", "WhichOneof", "library", "node", "version", "versions"]',
 '------------------------------------------------------------------------------------------',
 producer: 808)

In [26]:
_DIR(source_graph_def.node), '-'*60, source_graph_def.node[0:2]

('google.protobuf.pyext._message.RepeatedCompositeContainer: ["MergeFrom", "add", "append", "extend", "insert", "pop", "remove", "reverse", "sort"]',
 '------------------------------------------------------------',
 [name: "input_1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 64
        }
        dim {
          size: 112
        }
        dim {
          size: 112
        }
        dim {
          size: 3
        }
      }
    }
  },
  name: "StatefulPartitionedCall/resnet_period_estimator_29/mul_1/x"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: -1.0
      }
    }
  }])

In [27]:
node0 = source_graph_def.node[0]
_DIR(node0), '-'*60, node0.name, '-'*30, node0.input, '-'*30, node0.op, '-'*30, node0.attr, type(node0.attr)

('tensorflow.core.framework.node_def_pb2.NodeDef: ["AttrEntry", "ByteSize", "Clear", "ClearExtension", "ClearField", "CopyFrom", "DESCRIPTOR", "DiscardUnknownFields", "ExperimentalDebugInfo", "Extensions", "FindInitializationErrors", "FromString", "HasExtension", "HasField", "IsInitialized", "ListFields", "MergeFrom", "MergeFromString", "ParseFromString", "RegisterExtension", "SerializePartialToString", "SerializeToString", "SetInParent", "UnknownFields", "WhichOneof", "attr", "device", "experimental_debug_info", "input", "name", "op"]',
 '------------------------------------------------------------',
 'input_1',
 '------------------------------',
 [],
 '------------------------------',
 'Placeholder',
 '------------------------------',
 {'dtype': type: DT_FLOAT
 , 'shape': shape {
   dim {
     size: -1
   }
   dim {
     size: 64
   }
   dim {
     size: 112
   }
   dim {
     size: 112
   }
   dim {
     size: 3
   }
 }
 },
 google.protobuf.pyext._message.MessageMapContainer)

In [28]:
float_val_node = None
tensor_content_node = None
for node in source_graph_def.node:
    for attr in node.attr:
        if attr == 'value':
            tensor = node.attr[attr].tensor
            if tensor.dtype == types_pb2.DT_FLOAT:
                if tensor.float_val and not float_val_node:
                    float_val_node = node
                if tensor.tensor_content and not tensor_content_node:
                    tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                    if np.prod(tensor_shape) < 48:
                        tensor_content_node = node
    if float_val_node and tensor_content_node:
        break

In [29]:
float_val_node, '-'*60, tensor_content_node

(name: "StatefulPartitionedCall/resnet_period_estimator_29/mul_1/x"
 op: "Const"
 attr {
   key: "dtype"
   value {
     type: DT_FLOAT
   }
 }
 attr {
   key: "value"
   value {
     tensor {
       dtype: DT_FLOAT
       tensor_shape {
       }
       float_val: -1.0
     }
   }
 },
 '------------------------------------------------------------',
 name: "unknown_174"
 op: "Const"
 attr {
   key: "dtype"
   value {
     type: DT_FLOAT
   }
 }
 attr {
   key: "value"
   value {
     tensor {
       dtype: DT_FLOAT
       tensor_shape {
         dim {
           size: 32
         }
       }
       tensor_content: "\020\264\3159\373\352b\270hf\020;19h;\300\003\202\271\022\264\211\272\331v\013\270\3761Q:J\rA\273L\320\265:\t\273\352\273;\222\007\2724\034\216\272\037-\327\272T\\\333\272\360Q\233\273\260_v\272\"\032\035;*\367\237:\217p\310:\274\274C8\372\252\345\273\\5)\271_\260\232\272\336-W;{\241\226\267\211\304\216\273\334\346\227:(\023\264\270E\367\307\272\346?\002\273\375\344\215\271"
 

## Convert Weight to Float16 

In [30]:
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from google.protobuf import text_format

In [31]:
target_graph_def = graph_pb2.GraphDef()
target_graph_def.versions.CopyFrom(source_graph_def.versions)

In [32]:
target_graph_def

versions {
  producer: 808
}

In [33]:
dtype_f16 = types_pb2.DT_HALF
for node in source_graph_def.node:
    new_node = target_graph_def.node.add()
    new_node.op = node.op
    new_node.name = node.name
    new_node.input.extend(node.input)
    for attr in node.attr:
        if node.attr[attr].type == types_pb2.DT_FLOAT:
            new_node.attr[attr].type = dtype_f16
        if attr == "value":
            tensor = node.attr[attr].tensor
            if tensor.dtype == types_pb2.DT_FLOAT:
                if tensor.float_val:
                    float_val = tf.make_ndarray(node.attr[attr].tensor)
                    new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype_f16))
                    continue
                if tensor.tensor_content:
                    tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                    tensor_weights = np.reshape(tf.make_ndarray(tensor), tensor_shape)
                    new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(tensor_weights, dtype=dtype_f16))
                    continue
        new_node.attr[attr].CopyFrom(node.attr[attr])

In [34]:
tf.io.write_graph(source_graph_def, logdir=f'{MODEL_PATH}', name='frozen_graph_f32.pb', as_text=False);
tf.io.write_graph(target_graph_def, logdir=f'{MODEL_PATH}', name='frozen_graph_f16.pb', as_text=False);

In [35]:
!ls -lh $MODEL_PATH/frozen_graph_f*.pb

-rw-rw-rw- 1 root root 50M Oct  8 20:00 /data/nb_data/saved_models/frozen_graph_f16.pb
-rw-rw-rw- 1 root root 99M Oct  8 20:00 /data/nb_data/saved_models/frozen_graph_f32.pb


## References

- [post-training-quantization-of-tensorflow-model-to-fp16][1]
- [tf.compat.v1.wrap_function][2]


[2]: https://tensorflow.google.cn/api_docs/python/tf/compat/v1/wrap_function
[1]: https://medium.com/@fanzongshaoxing/post-training-quantization-of-tensorflow-model-to-fp16-8d66b9dfa77f