In [None]:
from onnx import TensorProto
from qonnx.core.modelwrapper import ModelWrapper
from slice_template import slice_node

# Open the Tinyyolo model
tmp_model = ModelWrapper("tinyyolo_infershapes.onnx")

# Replacement nodes, keyed by the name of the Slice node they replace.
# Each value is [index of the original node in the graph, replacement node].
update_dict = {}
# Positional inputs of a Slice node in the correct order, followed by its output.
# NOTE: the ONNX operator spec calls the fifth Slice input `steps`; the key "splits"
# is kept as-is so the printed stats keep their existing labels.
key_list = ["input", "starts", "ends", "axes", "splits", "output"]
# Names of all tensors attached to the Slice nodes that get replaced
drop_tensors = []

# Loop over all the nodes in the graph that are of op_type Slice
for s in tmp_model.get_nodes_by_op_type("Slice"):
    # Position of the node in the graph; stored next to its replacement in update_dict
    node_inx = tmp_model.get_node_index(s)
    # Input tensor names together with their shapes and initializer values
    t_names = list(s.input)
    shapes = [tmp_model.get_tensor_shape(name) for name in t_names]
    init = [tmp_model.get_initializer(name) for name in t_names]

    # attr maps every key to [tensor name, shape, initializer value]. Entries stay
    # [None, None, None] for inputs this node does not have (not all Slice nodes
    # carry the optional trailing inputs).
    attr = {k: [None, None, None] for k in key_list}
    for key, name, shape, value in zip(key_list[:-1], t_names, shapes, init):
        attr[key] = [name, shape, value]

    # The replacement below is built from starts/ends/axes only, so a Slice that
    # strides (steps != 1) cannot be expressed by it. Leave such a node untouched
    # rather than silently swapping in a stride-1 slice.
    # Assumes get_initializer returns a numpy array when the tensor is constant.
    steps_name, _, steps_value = attr["splits"]
    if steps_name and (steps_value is None or (steps_value != 1).any()):
        print("Skipping Slice node %r: non-unit or non-constant steps are not supported" % s.name)
        continue

    # Register the input tensors of this node as belonging to a replaced node
    drop_tensors.extend(t_names)

    # Do the same as above but for the output of the slice node
    for out_name in s.output:
        attr["output"] = [out_name, tmp_model.get_tensor_shape(out_name), None]
        drop_tensors.append(out_name)

    # Print some stats
    print("*"*110)
    print(attr)
    # Create a new slice node with attr instead of inputs
    x = slice_node(input_shape=attr["input"][1], output_shape=attr["output"][1],
                   starts_value=attr["starts"][2], ends_value=attr["ends"][2], axes_value=attr["axes"][2],
                   input_tensor=attr["input"][0], output_tensor=attr["output"][0],
                   dtype=TensorProto.INT64, node_name=s.name)
    # update_dict stores the new slice node: key = slice name, value = [index, node]
    update_dict[s.name] = [node_inx, x.make_node()]
    # Store the model of the single slice node. This is for double checking
    x.make_model(s.name+".onnx")


# Create a new model
In this part of the notebook, I create a new model by removing the old Slice nodes from the ONNX file and replacing them with the new slice nodes created earlier in this notebook.

In [None]:
import onnx
# Load in the tinyyolo model
model = ModelWrapper("tinyyolo_infershapes.onnx")

# Names of the nodes that need to be replaced, and their replacements (same order)
replace_nodes = update_dict.keys()
new_nodes = [entry[1] for entry in update_dict.values()]

# Build the node list of the new graph. Every replacement is placed at the position
# of the node it replaces: ONNX requires graph.node to be topologically sorted, so
# appending the new nodes after all the others would put them behind their consumers.
pending = dict(zip(replace_nodes, new_nodes))
keep_nodes = []
for n in model.graph.node:
    if n.name in replace_nodes:
        # The first node carrying this name is swapped for its replacement; any
        # further node with the same name is removed without a substitute.
        replacement = pending.pop(n.name, None)
        if replacement is not None:
            keep_nodes.append(replacement)
    else:
        keep_nodes.append(n)
# Replacements whose original node was not found in this graph go at the end
keep_nodes.extend(pending.values())

# The new graph declares the same (first) input and output as the original model
graph_inputs = [model.graph.input[0]]
graph_outputs = [model.graph.output[0]]
io_names = {vi.name for vi in graph_inputs + graph_outputs}

# A tensor is only really gone if no node of the new graph references it any more.
# The data input and the output of a replaced Slice node are still used by its
# replacement, so their value info and initializers must survive.
still_used = {name for n in keep_nodes for name in list(n.input) + list(n.output)}
dropped = set(drop_tensors) - still_used

# Value info for the new graph and a dict of the initializers to carry over
vinfo = []
t_init = {}
tensors = model.get_all_tensor_names()

# Loop over all tensors in the model and skip the ones that were dropped
for t in tensors:
    if t in dropped:
        continue
    # The graph input/output are declared through inputs=/outputs= below, so they
    # are filtered out by name instead of being listed a second time as value_info.
    if t not in io_names:
        vinfo.append(model.get_tensor_valueinfo(t))
    # Store the initialization value if it exists
    init_value = model.get_initializer(t)
    if init_value is not None:
        t_init[t] = init_value

# Create a new model
new_model = ModelWrapper(onnx.helper.make_model(
    onnx.helper.make_graph(
        nodes=keep_nodes,
        name="tinyyolo_infershapes_updated",
        inputs=graph_inputs,
        outputs=graph_outputs,
        value_info=vinfo
        )
    )
)

# Re-attach the initializers of all tensors that were kept
for k, v in t_init.items():
    new_model.set_initializer(k, v)
new_model.save("tinyyolo_slice_update.onnx")
