In [None]:
import os

import numpy as np
import onnx
import onnx.version_converter as vc
from onnx import helper, TensorProto
from onnx.backend.test.case.node import expect
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import gen_finn_dt_tensor


In [2]:
class slice_template():
    """Build, check and save a single-node ONNX ``Slice`` model (opset 9).

    The slice parameters are stored on the node as attributes
    (``starts`` / ``ends`` / ``axes``), matching the opset-9 operator set
    id created in ``__init__``.

    Parameters
    ----------
    input_shape, output_shape : int or sequence of int
        Shape of the graph input / output. A bare int is treated as a
        1-D shape, e.g. ``4`` -> ``[4]``.
    param_shape : int or sequence of int
        Stored as ``self.slice_shape``; not used when building the graph.
    starts_shape, ends_shape, axes_shape : int or sequence of int
        Values of the ``starts``, ``ends`` and ``axes`` Slice attributes.
    splits_shape : int or sequence of int
        Stored in ``self.slice_attr["splits"]`` only. It is never emitted on
        the node because ``Slice`` has no ``splits`` attribute.
    input_tensor, output_tensor : str
        Names of the graph input / output tensors.
    idtype, paramdt : str
        qonnx ``DataType`` names, stored as ``self.idt`` / ``self.param_dt``.
    """

    # Keys of ``slice_attr`` that are real ONNX Slice attributes.
    _SLICE_NODE_ATTRS = ("starts", "ends", "axes")

    @staticmethod
    def _as_dims(value):
        """Return ``value`` as a list of dims; a scalar becomes a 1-D shape."""
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    def __init__(
            self, input_shape = 4,
            output_shape = 2, param_shape = 1,
            starts_shape = 0, ends_shape = 2,
            axes_shape = 0, splits_shape = 1,
            input_tensor = "slice_in", output_tensor = "slice_out",
            idtype = "UINT8", paramdt = "INT64"):
        self.inp_shp = self._as_dims(input_shape)
        self.out_shp = self._as_dims(output_shape)
        self.slice_shape = self._as_dims(param_shape)
        # DataType[...] raises KeyError for unknown names, so the default
        # must be the valid "UINT8".
        self.idt = DataType[idtype]
        self.param_dt = DataType[paramdt]
        # make_tensor_value_info iterates over the shape, so pass the
        # normalised dim list rather than the raw (possibly int) argument.
        # NOTE(review): the element type is always INT64; self.idt is not
        # applied to these tensors. Confirm whether that is intended.
        self.input = helper.make_tensor_value_info(input_tensor, TensorProto.INT64, self.inp_shp)
        self.output = helper.make_tensor_value_info(output_tensor, TensorProto.INT64, self.out_shp)
        # ndmin=1 gives a 1-element array for a scalar (as before) and also
        # accepts a sequence of values for multi-axis slicing.
        self.slice_attr = {
            "starts": np.array(starts_shape, dtype=np.int64, ndmin=1),
            "ends": np.array(ends_shape, dtype=np.int64, ndmin=1),
            "axes": np.array(axes_shape, dtype=np.int64, ndmin=1),
            # Kept for reference only; make_node does not emit it.
            "splits": np.array(splits_shape, dtype=np.int64, ndmin=1),
        }
        self.opset_version = helper.make_operatorsetid("", 9)

    def make_node(self):
        """Create the ``Slice`` node and store it in ``self.slice_node``.

        Only the valid Slice attributes (``starts``, ``ends``, ``axes``) are
        passed. Their values are converted to plain Python ints before being
        handed to ``helper.make_node``.
        """
        node_attrs = {
            key: [int(v) for v in value]
            for key, value in self.slice_attr.items()
            if key in self._SLICE_NODE_ATTRS
        }
        self.slice_node = helper.make_node(
            "Slice",
            inputs=[self.input.name],
            outputs=[self.output.name],
            **node_attrs
        )

    def make_model(self, save_path="onnx_model/Slice_model.onnx"):
        """Wrap the Slice node in a graph/model, check it and save it.

        Parameters
        ----------
        save_path : str
            Destination file. Its parent directory is created if missing.

        Returns
        -------
        ModelWrapper
            The shape-inferred model, also stored in ``self.model``.
        """
        # Build the node on demand so make_model works on its own.
        if not hasattr(self, "slice_node"):
            self.make_node()

        graph = helper.make_graph(
            [self.slice_node],
            name="slice_graph",
            inputs=[self.input],
            outputs=[self.output],
        )
        model = ModelWrapper(helper.make_model(graph, opset_imports=[self.opset_version]))
        # ModelWrapper.transform returns the transformed model instead of
        # mutating in place, so its result must be kept.
        self.model = model.transform(InferShapes())
        onnx.checker.check_model(self.model.model)

        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        self.model.save(save_path)
        return self.model


In [None]:
# input_shape = [4]
# output_shape = [2]
# param_shape = [1]
# idt = DataType["UINT8"]
# param_dt = DataType["INT64"]
# slice_in = helper.make_tensor_value_info("735", TensorProto.INT64, input_shape)
# slice_out = helper.make_tensor_value_info("slice_out", TensorProto.INT64, output_shape)
# slice_attr = {}
# slice_attr["starts"] = np.array([0],dtype=np.int64)
# slice_attr["ends"] = np.array([2],dtype=np.int64)
# slice_attr["axes"] = np.array([0],dtype=np.int64)