Some general points and observations regarding the `Scan` node in ONNX:

1. `Scan` iterates over one or more scan input tensors and constructs zero or more scan output tensors. It combines ideas from general recurrences and from functional programming constructs such as scan, fold, map and zip.
2. The `body` attribute of the node must be a graph specifying the computation to be performed in every iteration.
3. In each iteration, the body graph takes as input the current values of the `state variables` and the current `iterated element` of each scan input. It returns the updated values of the `state variables` and the `scan output element tensors` (there can be more than one of each).
4. The values of the scan output element tensors are concatenated over all the iterations to produce the scan output values of the scan construct.
5. The properties that distinguish a scan node from a normal compute node are:
* It updates the hidden state variable after processing each input; the updated state is then used when processing the next input.
* It scans the inputs slice by slice (row by row or column by column), computes the output for every input from the updated hidden state, and stores all the intermediate outputs (here, the hidden states).
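
The semantics described above can be sketched in plain NumPy. This is only an illustrative reference loop, not the ONNX implementation: the helper names `scan` and `running_sum` are hypothetical, and the toy body (a running sum) was chosen only because its state and scan output are easy to check by hand.

```python
import numpy as np

def scan(body, init_state, scan_input):
    """Reference loop for a body of the form body(state, x) -> (new_state, out_elem).

    Assumes a single state variable, a single scan input and at least one iteration.
    """
    state = init_state
    outputs = []
    for x in scan_input:                  # iterate over axis 0 of the scan input
        state, out_elem = body(state, x)  # the state is carried to the next iteration
        outputs.append(out_elem)          # the scan output element is stored
    return state, np.stack(outputs)       # final state, stacked scan outputs

def running_sum(state, x):
    # Toy body: the new state is the sum so far; it is also emitted as the scan output.
    new_state = state + x
    return new_state, new_state

xs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
final_state, all_states = scan(running_sum, np.zeros(2), xs)
print(final_state)   # the final state holds the values 9 and 12
print(all_states)    # one row per iteration; the last row equals the final state
```

The final state is returned once, while the scan output has one entry per iteration, which is the same split the `Scan` node makes between its state outputs and its scan outputs.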

In this example, I am implementing one of the six equations used to solve for an LSTM output.

    g_out = W * X + U * h_t-1 + b

* Here, `g_out` is the output of one of the four gates involved in the LSTM equations. (The graph built below additionally applies a `Sigmoid` activation to it.)
* `X` is the input, `h_t-1` is the previous hidden state.
* `W`, `U` and `b` are the weight matrix, recurrence matrix and the bias for that gate.
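
As a rough NumPy sketch of a single evaluation of this equation, including the sigmoid that the graph in this notebook applies: the function names `sigmoid` and `gate_step` and the random test values are assumptions made for illustration, and the sizes (input size 10, hidden size 20) are the ones used later in this notebook.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gate_step(W, U, b, x, h_prev):
    """One gate evaluation: sigmoid(W @ x + U @ h_prev + b)."""
    return sigmoid(W @ x + U @ h_prev + b)

rng = np.random.default_rng(0)                           # arbitrary seed, illustration only
W = rng.standard_normal((20, 10)).astype(np.float32)     # weight matrix
U = rng.standard_normal((20, 20)).astype(np.float32)     # recurrence matrix
b = np.zeros((20, 1), dtype=np.float32)                  # bias
x = rng.standard_normal((10, 1)).astype(np.float32)      # current input
h_prev = np.zeros((20, 1), dtype=np.float32)             # previous hidden state

g_out = gate_step(W, U, b, x, h_prev)
print(g_out.shape)  # (20, 1)
```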

In [1]:
import onnx
import numpy as np
from qonnx.util.basic import qonnx_make_model
from finn.util.visualization import showInNetron
import onnxruntime as rt
from onnx.helper import make_tensor_value_info, make_node, make_graph, make_model, make_tensor

#### Part 1 : We first define the compute graph that we want to execute inside the `Scan` node. 
* We assume an input size of 10 and a hidden state size of 20.

In [2]:
# Quantize -> clip -> dequantize the weight matrix 'W_s'.
ql_w = make_node("QuantizeLinear", inputs=["W_s","scale_all","zero_point_all"], outputs=["ql_ws"], name="ql_w")
clp_w = make_node("Clip", inputs=["ql_ws","min","max"], outputs=["clp_ws"], name="clp_ws")
dql_w = make_node("DequantizeLinear", inputs=["clp_ws","scale_all","zero_point_all"], outputs=["dql_ws"], name="dql_w")

# Quantize -> clip -> dequantize the recurrence matrix 'U_s'.
ql_u = make_node("QuantizeLinear", inputs=["U_s","scale_all","zero_point_all"], outputs=["ql_us"], name="ql_u")
clp_u = make_node("Clip", inputs=["ql_us","min","max"], outputs=["clp_us"], name="clp_u")
dql_u = make_node("DequantizeLinear", inputs=["clp_us","scale_all","zero_point_all"], outputs=["dql_us"], name="dql_u")
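
The effect of the `QuantizeLinear` -> `Clip` -> `DequantizeLinear` chain on a weight tensor can be approximated in NumPy as follows. This is a sketch under assumptions: the helper name `fake_quantize` is hypothetical, the defaults (scale 1, zero point 0, clip range [-7, 7]) mirror the initializers defined later in this notebook, and the code is meant as an illustration rather than a bit-exact model of the ONNX operators.

```python
import numpy as np

def fake_quantize(x, scale=1.0, zero_point=0, qmin=-7, qmax=7):
    # QuantizeLinear: scale, round to nearest (ties to even), add zero point, saturate to int8
    q = np.clip(np.rint(x / scale) + zero_point, -128, 127).astype(np.int8)
    # Clip: restrict the int8 values to the narrower [qmin, qmax] range
    q = np.clip(q, qmin, qmax)
    # DequantizeLinear: map the integers back to float
    return (q.astype(np.float32) - zero_point) * scale

w = np.array([[0.4, 1.0, 2.5, 9.7, -12.3]], dtype=np.float32)
print(fake_quantize(w))  # values 0, 1, 2, 7, -7
```

With a scale of 1 and all-ones weights, as used in this notebook, the chain leaves the weights unchanged; it only becomes visible for values that are fractional or fall outside the clip range.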

In [3]:
# Define the inputs and outputs of the body graph.
# Inputs
inp_X = make_tensor_value_info("X", onnx.TensorProto.FLOAT, [10, 1])
inp_h_t_1 = make_tensor_value_info("h_t-1", onnx.TensorProto.FLOAT, [20, 1])

# Outputs
out_state = make_tensor_value_info("s_t", onnx.TensorProto.FLOAT, [20, 1])
scan_out = make_tensor_value_info("scan_out", onnx.TensorProto.FLOAT, [20, 1])

In [4]:
# Define the individual nodes of the body graph.
# --------------------------------------------
# W * X
mul_node1 = make_node("MatMul", inputs=["dql_ws", "X"], outputs=["out_m1"], name="mul_node1")

# U * h_t-1
mul_node2 = make_node("MatMul", inputs=["dql_us", "h_t-1"], outputs=["out_m2"], name="mul_node2")

# W * X + U * h_t-1
add_node1 = make_node("Add", inputs=["out_m1", "out_m2"], outputs=["out_add1"], name="add_node1")

# W * X + U * h_t-1 + b
add_node2 = make_node("Add", inputs=["out_add1", "b_s"], outputs=["s_t_ba"], name="add_node2")

# Gate activation; 's_t' is the updated state variable.
sig_s = make_node("Sigmoid", inputs=["s_t_ba"], outputs=["s_t"], name="sig_s")

# Copy the state into a second tensor that serves as the scan output element.
id_node = make_node("Identity", inputs=["s_t"], outputs=["scan_out"])

In [5]:
bias_val = np.ones([20, 1], dtype=np.float32)
Ws_val = np.ones([20, 10], dtype=np.float32)
Us_val = np.ones([20, 20], dtype=np.float32)
gen_lstm_eq = onnx.helper.make_graph(
    nodes=[
           ql_w,
           clp_w,
           dql_w,
           ql_u,
           clp_u,
           dql_u, 
           mul_node1, 
           mul_node2, 
           add_node1, 
           add_node2,
           sig_s,
           id_node
          ],
    name = "Scan-Body",
    inputs=[inp_h_t_1, inp_X],  # State variable first, then the scan input, to match the input order of the Scan node defined later.
    outputs = [out_state, scan_out],
    value_info=[
            make_tensor_value_info("out_m1",onnx.TensorProto.FLOAT, [20,1]),
            make_tensor_value_info("out_m2",onnx.TensorProto.FLOAT, [20,1]),
            make_tensor_value_info("out_add1",onnx.TensorProto.FLOAT, [20,1]),
            make_tensor_value_info("s_t_ba",onnx.TensorProto.FLOAT, [20,1]),
            make_tensor_value_info("ql_ws", onnx.TensorProto.INT8, [20,10]),
            make_tensor_value_info("dql_ws",onnx.TensorProto.FLOAT, [20,10]),
            make_tensor_value_info("ql_us", onnx.TensorProto.INT8, [20,20]),
            make_tensor_value_info("dql_us",onnx.TensorProto.FLOAT, [20,20])
        ],
    initializer=[make_tensor('W_s',onnx.TensorProto.FLOAT, [20,10], (Ws_val)),
                 make_tensor('U_s',onnx.TensorProto.FLOAT, [20,20], (Us_val)),
                 make_tensor('b_s',onnx.TensorProto.FLOAT, [20,1], (bias_val)),
                 make_tensor('scale_all',onnx.TensorProto.FLOAT,[],[1]),
                 make_tensor('zero_point_all',onnx.TensorProto.INT8,[],[0]),
                 make_tensor('min',onnx.TensorProto.INT8, [],[-7]),
                 make_tensor('max',onnx.TensorProto.INT8, [], [7]),
                 make_tensor('bitwidth',onnx.TensorProto.INT32, [], [4])
                ]
)
# Some points to note here:
# 1. The initializers ('W_s', 'U_s', 'b_s', ...) are part of the model, so they are not listed among the graph inputs.
# 2. For the same reason, these initializers are not defined again when we define the Scan node later.
# 3. The Scan node only cares about the inputs and outputs of the body graph, not about what happens inside it.

In [6]:
onnx_model = qonnx_make_model(gen_lstm_eq, producer_name="LSTM_eq")
onnx.save(onnx_model, './gen_lstm_eq.onnx')
#showInNetron('./gen_lstm_eq.onnx')#,localhost_url='xirxlabs53'

In [7]:
# Convert the model to opset 14: the Clip operator in the opset the model was created with
# only allowed FLOAT inputs, whereas here it is applied to INT8 tensors.
from onnx import version_converter, helper
onnx_model_14 = version_converter.convert_version(onnx_model, 14)
# print(onnx_model_14)

Testing the above graph with `onnxruntime` execution

In [8]:
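# Note: `high` is exclusive in np.random.randint, so low=0, high=1 produces all-zero inputs.
# With zero inputs the gate reduces to sigmoid(b) = sigmoid(1) ≈ 0.7310586, which is the value printed below.
in_X = np.asarray(np.random.randint(low=0, high=1, size=(10,1)), dtype=np.float32)
in_h_t_1 = np.asarray(np.random.randint(low=0, high=1, size=(20,1)), dtype=np.float32)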
input_dict = {}
input_dict["X"] = in_X
input_dict["h_t-1"] = in_h_t_1

sess = rt.InferenceSession(onnx_model_14.SerializeToString())
output = sess.run(None, input_dict)
print(output)

[array([[0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586]], dtype=float32), array([[0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586],
       [0.7310586]], dtype=float32)]


2023-10-12 10:10:56.548822729 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer 'bitwidth'. It is not used by any node and should be removed from the model.


#### Part 2 : Now defining the scan node.

The node will incorporate the above graph we defined. 

* The inputs to the node will be the current input `X` and the previous `hidden_state`.
* The outputs of the node will be a tensor with the `final_state` of the gate and a tensor with all the `intermediate_states` of the gate computed for each input. (This replicates what we need in our LSTM compute, where the output tensors are the `final_cell_state`, the `final_hidden_state` and all the `concatenated_hidden_states`.)

In [9]:
# Define the input and output value info tensors for the scan graph. These tensors act as a wrapper around the previously defined body graph.

# Inputs
scan_input = make_tensor_value_info(
    "scan_input", onnx.TensorProto.FLOAT, [None, 10, 1]
)  # X; scan input. 'None' marks a variable first dimension: the number of inputs supplied for processing.

inp_a = make_tensor_value_info("inp_a", onnx.TensorProto.FLOAT, [20, 1])  # h_t-1

# Outputs
out_a = make_tensor_value_info("out_a", onnx.TensorProto.FLOAT, [20, 1])  # s_t
out_b = make_tensor_value_info("out_b", onnx.TensorProto.FLOAT, [None, 20, 1])  # scan_out

In [10]:
# Define the Scan node.
scan_node_gen_lstm_eq = make_node(
    "Scan",
    inputs=["inp_a", "scan_input"],
    outputs=["out_a", "out_b"],
    num_scan_inputs=1,
    body=gen_lstm_eq,
    domain='',
)  # The order of the tensors in the inputs and outputs matters: state variables come first, then the scan inputs / scan outputs.
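
To make the ordering convention concrete, here is a small pure-Python sketch of how the input and output lists of a `Scan` node divide into state variables and scan tensors. The helper `split_scan_io` is hypothetical (it is not part of `onnx`), and it assumes the layout in which state variables precede the scan inputs and final states precede the scan outputs.

```python
def split_scan_io(inputs, outputs, num_scan_inputs):
    """Split Scan node inputs/outputs into state variables and scan tensors.

    Assumes state variables come first in both lists.
    """
    num_state = len(inputs) - num_scan_inputs
    return {
        "state_inputs": inputs[:num_state],
        "scan_inputs": inputs[num_state:],
        "final_states": outputs[:num_state],
        "scan_outputs": outputs[num_state:],
    }

roles = split_scan_io(["inp_a", "scan_input"], ["out_a", "out_b"], num_scan_inputs=1)
print(roles)
```

For the node defined above, `inp_a` is the initial state, `scan_input` is the tensor that gets sliced per iteration, `out_a` is the final state and `out_b` collects the per-iteration outputs. The body graph must list its inputs and outputs in the same order.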

In [11]:
gen_lstm_scan_graph = make_graph(
    nodes = [scan_node_gen_lstm_eq],
    name="gen_eq_graph",
    inputs=[inp_a,scan_input],
    outputs=[out_a,out_b]
)

In [12]:
gen_scan_model = qonnx_make_model(gen_lstm_scan_graph, producer_name="eq-model")
onnx.save(gen_scan_model, './gen_scan_model.onnx')
print(gen_scan_model)

ir_version: 8
producer_name: "eq-model"
graph {
  node {
    input: "inp_a"
    input: "scan_input"
    output: "out_a"
    output: "out_b"
    op_type: "Scan"
    attribute {
      name: "body"
      g {
        node {
          input: "W_s"
          input: "scale_all"
          input: "zero_point_all"
          output: "ql_ws"
          name: "ql_w"
          op_type: "QuantizeLinear"
        }
        node {
          input: "ql_ws"
          input: "min"
          input: "max"
          output: "clp_ws"
          name: "clp_ws"
          op_type: "Clip"
        }
        node {
          input: "clp_ws"
          input: "scale_all"
          input: "zero_point_all"
          output: "dql_ws"
          name: "dql_w"
          op_type: "DequantizeLinear"
        }
        node {
          input: "U_s"
          input: "scale_all"
          input: "zero_point_all"
          output: "ql_us"
          name: "ql_u"
          op_type: "QuantizeLinear"
        }
        node {
          i

In [13]:
#Have to convert the opset version of the graph here because the clip operator in the previous version did not allow for INT8 inputs.
# It only allowed for FLOAT inputs.
from onnx import version_converter, helper
gen_scan_model_14 = version_converter.convert_version(gen_scan_model, 14)
print(gen_scan_model_14)

ir_version: 8
producer_name: "eq-model"
graph {
  node {
    input: "inp_a"
    input: "scan_input"
    output: "out_a"
    output: "out_b"
    op_type: "Scan"
    attribute {
      name: "body"
      g {
        node {
          input: "W_s"
          input: "scale_all"
          input: "zero_point_all"
          output: "ql_ws"
          name: "ql_w"
          op_type: "QuantizeLinear"
        }
        node {
          input: "ql_ws"
          input: "min"
          input: "max"
          output: "clp_ws"
          name: "clp_ws"
          op_type: "Clip"
        }
        node {
          input: "clp_ws"
          input: "scale_all"
          input: "zero_point_all"
          output: "dql_ws"
          name: "dql_w"
          op_type: "DequantizeLinear"
        }
        node {
          input: "U_s"
          input: "scale_all"
          input: "zero_point_all"
          output: "ql_us"
          name: "ql_u"
          op_type: "QuantizeLinear"
        }
        node {
          i

In [14]:
#Checking the model for any errors.
onnx.checker.check_model(gen_scan_model_14)
print(gen_scan_model.graph.value_info)

[]


In [15]:
showInNetron('./gen_scan_model.onnx')#localhost_url='xirxlabs53'

Stopping http://0.0.0.0:5901
Serving './gen_scan_model.onnx' at http://0.0.0.0:5901


Testing the new scan node with `3 inputs`. (Change `n` to run the node on any number of inputs.)

In [15]:
n = 3
# Note: `high` is exclusive in np.random.randint, so both tensors below are all zeros.
scan_inp_X = np.asarray(np.random.randint(low=0, high=1, size=(n,10,1)), dtype=np.float32)
scan_inp_h_t_1 = np.asarray(np.random.randint(low=0, high=1, size=(20,1)), dtype=np.float32)
input_dict = {}
input_dict["scan_input"] = scan_inp_X
input_dict["inp_a"] = scan_inp_h_t_1

sess = rt.InferenceSession(gen_scan_model_14.SerializeToString())
output = sess.run(None, input_dict)
print('Final Hidden State : ',output[0])
print('All intermediate hidden states : ', output[1])

Final Hidden State :  [[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
All intermediate hidden states :  [[[0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]
  [0.7310586]]

 [[0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]
  [0.9999999]]

 [[1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.       ]
  [1.  

2023-10-12 10:11:23.265230994 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer 'bitwidth'. It is not used by any node and should be removed from the model.
2023-10-12 10:11:23.267105686 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_ws'. It is not used by any node and should be removed from the model.
2023-10-12 10:11:23.267140631 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_us'. It is not used by any node and should be removed from the model.
2023-10-12 10:11:23.267167421 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer 'max'. It is not used by any node and should be removed from the model.
2023-10-12 10:11:23.267193620 [W:onnxruntime:, graph.cc:3543 CleanUnusedInitializersAndNodeArgs] Removing initializer 'min'. It is not used by any node and should be removed from the model.
2023-10-12 10:11:23.267220034 [W:onnxrunt

The first output tensor is the `final_hidden_state`; the second output tensor holds all the `n_intermediate_states` (including the `final_hidden_state`).
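
As a sanity check, the values printed above can be reproduced with a plain NumPy loop. This is a reference sketch rather than the ONNX Runtime execution path. It assumes the all-ones weights and bias and the all-zero inputs used in this notebook, and it omits the quantization nodes on the assumption that, with a scale of 1, they leave all-ones weights unchanged.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

n = 3
W = np.ones((20, 10), dtype=np.float32)        # weight matrix
U = np.ones((20, 20), dtype=np.float32)        # recurrence matrix
b = np.ones((20, 1), dtype=np.float32)         # bias
xs = np.zeros((n, 10, 1), dtype=np.float32)    # all-zero scan input
h = np.zeros((20, 1), dtype=np.float32)        # initial hidden state

states = []
for x in xs:                                   # one iteration per slice along axis 0
    h = sigmoid(W @ x + U @ h + b)             # update the state variable
    states.append(h)                           # keep the scan output element
all_states = np.stack(states)                  # shape (n, 20, 1)

print(h[0, 0], all_states[:, 0, 0])
```

The first state is sigmoid(1) ≈ 0.7310586; the second is sigmoid(20 × 0.7310586 + 1), which is already within about 2e-7 of 1; the third rounds to 1 in float32. This matches the final and intermediate hidden states returned by the scan node.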