In [11]:
import os

import numpy as np
import onnx
import onnx_graphsurgeon as gs
import onnxruntime as ort
import torch
from onnx import TensorProto
from rich import print

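## ScatterND using ONNX GraphSurgeon

Build a single-node ScatterND reference model with ONNX GraphSurgeon. Then export a PyTorch replacement that performs the same update with slicing and concatenation, and check in ONNX Runtime that both models produce the same output.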

In [12]:
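## ScatterND shape rules, with r = rank(data), q = rank(indices), k = indices.shape[-1]:
##   rank(updates) = r + q - k - 1, i.e. updates.shape = indices.shape[:-1] + data.shape[k:]
## The output always has the same shape as data.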

In [13]:
# Graph I/O tensors for a 1-D ScatterND:
#   data           : float32[8]    - tensor whose elements get overwritten
#   indices        : int64[4, 1]   - four single-element index tuples into `data`
#   updates        : float32[4]    - one replacement value per index tuple
#   scatter_output : float32[8]    - same shape as `data`, as ScatterND requires
data = gs.Variable(name="data", dtype=TensorProto.FLOAT, shape=[8])
indices = gs.Variable(name="indices", dtype=TensorProto.INT64, shape=[4, 1])
updates = gs.Variable(name="updates", dtype=TensorProto.FLOAT, shape=[4])

scatter_output = gs.Variable(name="scatter_output", dtype=TensorProto.FLOAT, shape=[8])

In [14]:
# Reference model: one ScatterND node wired to the graph inputs/outputs declared
# in the previous cell. The model is exported at opset 18 and validated before
# being written to disk.
scatter_nd = gs.Node(
    op="ScatterND",
    name="scatter_node",
    inputs=[data, indices, updates],
    outputs=[scatter_output],
)

graph = gs.Graph(
    nodes=[scatter_nd],
    inputs=[data, indices, updates],
    outputs=[scatter_output],
    opset=18,
)

scatternd_model = gs.export_onnx(graph)
onnx.checker.check_model(scatternd_model)  # catch malformed graphs before ORT does
os.makedirs("./models", exist_ok=True)  # onnx.save does not create missing directories
onnx.save(scatternd_model, "./models/scatternd.onnx")

## Inputs

In [None]:

#INPUTS
# Draw random inputs for the reference ScatterND model. Shapes and input names are
# read back from the saved model, so they stay in sync with the graph definition.
# Index rows are drawn WITHOUT replacement over every position of `data`. ONNX
# ScatterND (reduction="none") does not define the result when `indices` contains
# duplicates, which would make the later model comparison unreliable.
rng = np.random.default_rng(0)  # fixed seed so Restart & Run All is reproducible

ort_session = ort.InferenceSession("./models/scatternd.onnx")
data_info, indices_info, updates_info = ort_session.get_inputs()

# Number of index tuples = product of all but the last dimension of `indices`.
num_index_rows = int(np.prod(indices_info.shape[:-1]))

d = rng.standard_normal(data_info.shape).astype(np.float32)
i = (
    rng.choice(data_info.shape[0], size=num_index_rows, replace=False)
    .reshape(indices_info.shape)
    .astype(np.int64)
)
u = rng.standard_normal(updates_info.shape).astype(np.float32)

input_values = [d, i, u]
input_dict = {
    info.name: value
    for info, value in zip((data_info, indices_info, updates_info), input_values)
}

print(input_dict)

# Torch views of the same arrays, used as example inputs for torch.onnx.export.
data = torch.from_numpy(d)
indices = torch.from_numpy(i)
updates = torch.from_numpy(u)


## Replacement: ScatterND via slicing and concatenation

In [None]:
import torch

class ScatterNDModule(torch.nn.Module):
    """ScatterND for 1-D ``data``, built only from slicing and concatenation.

    Reproduces ONNX ``ScatterND`` (reduction="none") for the case used in this
    notebook: ``data`` is 1-D and ``indices`` has shape ``(M, 1)``, so each row
    selects one element of ``data`` to overwrite with the matching entry of
    ``updates``. Each write is expressed as ``cat([prefix, update, suffix])``
    rather than an indexed assignment. Presumably this is so the exported graph
    contains no ScatterND node (confirm against the target runtime).

    NOTE(review): when exported by tracing, the Python loop over index rows is
    unrolled. The exported graph is therefore presumably fixed to the traced M.
    """

    def forward(self, data, indices, updates):
        """Return a copy of ``data`` with ``updates[i]`` written at ``indices[i, 0]``.

        Args:
            data: 1-D tensor to update; it is not modified.
            indices: integer tensor of shape ``(M, 1)``. Negative values count
                from the end of ``data``, as in Python indexing.
            updates: tensor of shape ``(M,)``; ``updates[i]`` goes to position
                ``indices[i, 0]``.

        Returns:
            A new tensor with the same shape and dtype as ``data``. Rows are
            applied in order, so when an index repeats the later row wins.
        """
        output = data.clone()
        length = output.size(0)

        for row in range(indices.size(0)):
            idx = indices[row, 0]
            # Map negative indices into [0, length). Otherwise `output[idx + 1:]`
            # would restart at the front and the result would grow in length.
            idx = torch.where(idx < 0, idx + length, idx)
            output = torch.cat(
                [output[:idx], updates[row:row + 1], output[idx + 1:]], dim=0
            )

        return output


# Export the slice/concat replacement so ONNX Runtime can compare it against the
# reference ScatterND model. The example inputs are the tensors from the Inputs
# cell. The input names deliberately match the reference graph, so one
# `input_dict` can feed both models.
module = ScatterNDModule()

onnx_path = "./models/final.onnx"
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)  # export does not create directories

torch.onnx.export(
    module,
    (data, indices, updates),
    onnx_path,
    input_names=["data", "indices", "updates"],
    output_names=["output"],
    opset_version=17,
)

print("ONNX model saved at:", onnx_path)

## ORT Inference

In [17]:

# Reuse `input_dict` from the Inputs cell instead of drawing a second random set.
# Both models are then evaluated on exactly the inputs printed there (the same
# inputs used as export examples).

# Reference: the single-node ScatterND graph built with GraphSurgeon.
reference_session = ort.InferenceSession("./models/scatternd.onnx")
output1 = reference_session.run(None, input_dict)
print(output1)

# Replacement: the slice/concat module exported from PyTorch.
replacement_session = ort.InferenceSession("./models/final.onnx")
output = replacement_session.run(None, input_dict)
print(output)

# The replacement only copies values, so it must reproduce ScatterND exactly.
# Assert here so a mismatch fails Restart & Run All instead of hiding in a metric.
np.testing.assert_allclose(output[0], output1[0])


In [18]:
from sklearn.metrics import mean_squared_error

# Mean squared error between the replacement's ORT result (`output`) and the
# reference ScatterND result (`output1`). Each is a list holding one output
# array. A value of 0.0 means the two models agree element-for-element.
mean_squared_error(y_true=output, y_pred=output1)

np.float32(0.0)