In [27]:
import onnx
import onnx_graphsurgeon as gs
import onnxruntime as ort
import torch
import numpy as np
from rich import print

In [28]:
def layer_wise(file, input_value):
    """Run an ONNX model and print every node's output, not only the graph outputs.

    The model is loaded from ``file``, each node output that is not already a
    graph output is appended to ``model.graph.output``, and the modified model
    is executed once with onnxruntime. The input and every resulting tensor
    are then printed with whatever ``print`` is in scope.

    Args:
        file: Path of the ONNX model, as accepted by ``onnx.load``.
        input_value: Array fed to the model's first input. Only the first
            input is fed, so models with several required inputs are not
            supported by this helper.

    Returns:
        None. The report is printed, not returned.
    """
    model = onnx.load(file)

    # Names already exposed as graph outputs. Reading them from the proto
    # avoids building a throwaway InferenceSession just to list them.
    exposed = {graph_output.name for graph_output in model.graph.output}

    for node in model.graph.node:
        for output in node.output:
            # An empty name marks an omitted optional output; there is no
            # tensor behind it, so it must not become a graph output.
            if output and output not in exposed:
                model.graph.output.extend([onnx.ValueInfoProto(name=output)])
                exposed.add(output)

    ort_session = ort.InferenceSession(model.SerializeToString())

    outputs = [x.name for x in ort_session.get_outputs()]
    inputs = [x.name for x in ort_session.get_inputs()]

    # Passing None as the output list asks for every graph output, returned
    # in the same order as get_outputs().
    ort_outs = ort_session.run(None, input_feed={inputs[0]: input_value})

    # A plain dict preserves insertion order, so names stay paired with
    # tensors in graph-output order.
    ort_outs = dict(zip(outputs, ort_outs))

    print(f"Input Shape :{input_value.shape }Input Value :{input_value}")

    for key, value in ort_outs.items():
        print(f"Layer : {key} Shape :{value.shape} Outputs :{value}")

## ReduceSum Mul axis 0

In [None]:

# Build and save a two-node graph: ReduceSum over axis 0, then Mul by 2.

# Graph input: a 2x2 int32 tensor.
input_tensor = gs.Variable("input_tensor", np.int32, (2, 2))

# Tensors produced inside the graph; their shapes are left unspecified.
reduced_output = gs.Variable("reduced_output", np.int32)
mul_output = gs.Variable("mul_output", np.int32)

# One-element int32 constant used as the multiplier.
mul_constant = gs.Constant("mul_constant", np.asarray([2], dtype=np.int32))

# ReduceSum node. `axes` is passed as an attribute, which is the form ReduceSum
# takes before opset 13.
# NOTE(review): this relies on the opset graphsurgeon exports by default - confirm.
reduce_sum_node = gs.Node(
    op="ReduceSum",
    name="ReduceSum_Node",
    attrs={"axes": [0]},
    inputs=[input_tensor],
    outputs=[reduced_output],
)

# Mul node: scales the reduced tensor by the constant.
mul_node = gs.Node(
    op="Mul",
    name="Mul_Node",
    inputs=[reduced_output, mul_constant],
    outputs=[mul_output],
)

graph = gs.Graph(
    nodes=[reduce_sum_node, mul_node],
    inputs=[input_tensor],
    outputs=[mul_output],
)

onnx_model = gs.export_onnx(graph)
onnx.save(onnx_model, "reduce_sum_mul_axis0_2x2.onnx")

## MatMul axis 0

In [None]:
# Build and save a Transpose -> MatMul -> Transpose graph.
# Multiplying the transposed input by a (2, 1) column of 2s sums the original
# input over axis 0 and doubles it; the last Transpose turns (2, 1) into (1, 2).

# Tensors of the graph, declared up front.
input_tensor = gs.Variable(name="input_tensor", dtype=np.int32, shape=(2, 2))
trans_tensor1 = gs.Variable(name="trans_tensor1", dtype=np.int32, shape=(2, 2))
matmul_output = gs.Variable(name="output_tensor", dtype=np.int32, shape=(2, 1))
trans_tensor2 = gs.Variable(name="trans_tensor2", dtype=np.int32, shape=(1, 2))

# (2, 1) int32 weight filled with 2.
weight = gs.Constant(
    name="matmul_weight",
    values=np.full(shape=(2, 1), fill_value=2, dtype=np.int32),
)

# Transpose without a `perm` attribute reverses the dimension order.
transpose_node1 = gs.Node(
    op="Transpose",
    name="Transpose1",
    inputs=[input_tensor],
    outputs=[trans_tensor1],
)

matmul_node = gs.Node(
    op="MatMul",
    name="MatMul_SumMul",
    inputs=[trans_tensor1, weight],
    outputs=[matmul_output],
)

transpose_node2 = gs.Node(
    op="Transpose",
    name="Transpose2",
    inputs=[matmul_output],
    outputs=[trans_tensor2],
)

graph = gs.Graph(
    nodes=[transpose_node1, matmul_node, transpose_node2],
    inputs=[input_tensor],
    outputs=[trans_tensor2],
)

onnx_model = gs.export_onnx(graph)
onnx.save(onnx_model, "matmul_axis0_2x2.onnx")


## MatMul axis 1

In [None]:
# Build and save a single-MatMul graph.
# Multiplying the (2, 2) input by a (2, 1) column of 2s sums each row
# (axis 1) and doubles it, giving a (2, 1) result.

input_tensor = gs.Variable(name="input_tensor", dtype=np.int32, shape=(2, 2))
matmul_output = gs.Variable(name="output_tensor", dtype=np.int32, shape=(2, 1))

# (2, 1) int32 weight filled with 2.
weight = gs.Constant(
    name="matmul_weight",
    values=np.full(shape=(2, 1), fill_value=2, dtype=np.int32),
)

matmul_node = gs.Node(
    op="MatMul",
    name="MatMul_SumMul",
    inputs=[input_tensor, weight],
    outputs=[matmul_output],
)

graph = gs.Graph(
    nodes=[matmul_node],
    inputs=[input_tensor],
    outputs=[matmul_output],
)

onnx_model = gs.export_onnx(graph)
onnx.save(onnx_model, "matmul_axis1_2x2.onnx")


## Inference

In [37]:
# The models built in this notebook declare an int32 input. NumPy's default
# integer dtype is platform-dependent (often int64), and onnxruntime rejects a
# feed whose dtype differs from the declared one, so the dtype is pinned here.
# The generator is seeded so the printed layer outputs are reproducible.
input_value = np.random.default_rng(0).integers(1, 5, size=(2, 2), dtype=np.int32)

In [38]:
# layer_wise prints its own report and returns None, so it is called as a plain
# statement; wrapping it in print() would only append a stray "None".
layer_wise(file="reduce_sum_mul_axis0_2x2.onnx", input_value=input_value)

layer_wise(file="matmul_axis0_2x2.onnx", input_value=input_value)

In [34]:
# layer_wise prints its own report and returns None, so it is called as a plain
# statement; wrapping it in print() would only append a stray "None".

# NOTE(review): no visible cell in this notebook writes
# "reduce_sum_mul_axis1.onnx"; on a fresh kernel this call fails unless the
# file already exists on disk - confirm where it is generated.
layer_wise(file="reduce_sum_mul_axis1.onnx", input_value=input_value)

# The MatMul axis-1 model is saved as "matmul_axis1_2x2.onnx", so that is the
# name loaded here.
layer_wise(file="matmul_axis1_2x2.onnx", input_value=input_value)