In [1]:
import onnx
from onnx import helper, TensorProto
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.core.datatype import DataType
from qonnx.util.basic import gen_finn_dt_tensor
import os

# Define a finder_fxn for ModelWrapper.find_upstream
def find_input_node(x, input_name="x"):
    """Return True if node ``x`` consumes the tensor named ``input_name``.

    Used as the stop condition for ``model.find_upstream`` below: the upstream
    walk ends at the node that reads the model's graph input.

    :param x: ONNX node (any object exposing an ``input`` sequence of tensor names).
    :param input_name: tensor name to look for. Defaults to ``"x"``, presumably
        the graph input name of the tinyyolo model used here — pass another
        name for models whose input is named differently.
    :return: ``True`` if ``input_name`` is among the node's inputs, else ``False``.
    """
    return input_name in x.input

# Name of the node at which the model is cut. The new model ends at this node's
# first input tensor; the split node itself is not included.
split_node = "MultiThreshold_10"
# Open the tinyyolo model
model = ModelWrapper("../tinyyolo_infershapes.onnx")

# Get onnx nodes
nodes = model.graph.node
passed_tensors, inialized_tensors = [], []

# Make a list of all tensors in the model which have an initializer
all_tensors = model.get_all_tensor_names()
for ten in all_tensors:
    if model.get_initializer(ten) is not None:
        inialized_tensors.append(ten)

# Collect the input-tensor lists of every node that precedes the split node in
# graph order, and remember the split node itself.
s_node = None
for n in nodes:
    if n.name == split_node:
        s_node = n
        break
    passed_tensors.append(n.input)
if s_node is None:
    # Fail loudly: otherwise a misspelled name surfaces later as a NameError,
    # or (on a re-run) silently reuses a stale `s_node` from the kernel.
    raise ValueError("Split node {} not found in the model graph".format(split_node))

# Create a dict mapping each initializer consumed by a passed node to its value
initialized_set = set(inialized_tensors)
init_tens = {}
for pt in passed_tensors:
    for t in pt:
        if t in initialized_set:
            init_tens[t] = model.get_initializer(t)
start_node = s_node.input[0]

# Walk the producer chain upstream from the split node's first input until
# find_input_node matches (the node reading the model input).
# NOTE(review): find_upstream presumably follows only each node's first input,
# so side branches are not collected — fine for a sequential network; confirm.
upstream_nodes = model.find_upstream(start_node, find_input_node)
if not upstream_nodes:
    raise RuntimeError(
        "No producer chain from tensor {} back to the model input was found".format(start_node)
    )

# The walk visits nodes from the cut point backwards; ONNX graphs must list
# nodes in topological (input-to-output) order, so reverse them.
up_n_ordered = list(reversed(upstream_nodes))


In [2]:

# Get input and output tensor names and shapes: the new model keeps the
# original graph input and ends at the output of the last upstream node.
m_in = model.graph.input[0].name
m_out = up_n_ordered[-1].output[0]
ish = model.get_tensor_shape(m_in)
osh = model.get_tensor_shape(m_out)

# Make tensor value info for the input and output of the model
# (both declared as float32, as in the original notebook).
inputs = helper.make_tensor_value_info(m_in, TensorProto.FLOAT, ish)
outputs = helper.make_tensor_value_info(m_out, TensorProto.FLOAT, osh)

# Collect the existing value info of every kept initializer. get_tensor_valueinfo
# may return None (e.g. for an initializer without a ValueInfoProto entry), and
# make_graph cannot accept None entries, so those are skipped.
value_info = []
for t in init_tens.keys():
    vi = model.get_tensor_valueinfo(t)
    if vi is not None:
        value_info.append(vi)


In [7]:
# Create a graph containing only the upstream nodes, from the original model
# input up to the split point.
new_graph = helper.make_graph(
    name="new_graph",
    inputs=[inputs],
    outputs=[outputs],
    value_info=value_info,
    nodes=up_n_ordered,
)

# Create a new model that only contains the nodes desired. Reuse the source
# model's opset imports: without them make_model declares the newest default
# ONNX opset, which may not match the opset the copied nodes were built for.
split_model = ModelWrapper(
    helper.make_model(new_graph, opset_imports=list(model.model.opset_import))
)

# Copy the initializers from the source model, together with any non-float32
# QONNX datatype annotation they carry.
for t, init_value in init_tens.items():
    split_model.set_initializer(t, init_value)
    dt = model.get_tensor_datatype(t)
    # Compare DataType objects directly rather than against a bare string, and
    # pass dt through instead of round-tripping it via str().
    if dt != DataType["FLOAT32"]:
        split_model.set_tensor_datatype(t, dt)

split_model = split_model.transform(InferShapes())
split_model = split_model.transform(InferDataTypes())

# Create the output directory if needed and always define the output path
# (it must exist on a fresh run too, not only when the directory pre-exists).
os.makedirs("model_files", exist_ok=True)
model_name = "model_files/split_model_{}.onnx".format(split_node)

split_model.save(model_name)

print("Split at node {}, and saved to {}".format(split_node, model_name))

INT4