
The calculation results of modelProto exported by onnxscript are variable #2203


Open
liguang-ops opened this issue Apr 15, 2025 · 4 comments
Labels
help wanted Extra attention is needed

Comments

@liguang-ops

liguang-ops commented Apr 15, 2025

Recently, I used onnxscript to implement the computation graph for farthest point sampling. When I tested the function by calling it directly (eager mode), the results were correct,

result_onnx = furthest_sampling(xyz, offset, new_offset)
print(result_onnx)

but after converting it to a ModelProto, the results from onnxruntime varied from run to run of the script: sometimes correct and sometimes wrong.

model = furthest_sampling.to_model_proto()
model = onnx.shape_inference.infer_shapes(model)
onnx.checker.check_model(model)
onnx.save_model(model, "fps.onnx")

from onnxruntime import InferenceSession
sess = InferenceSession(model.SerializeToString(), providers=("CPUExecutionProvider",))

got = sess.run(None,
    {
        "xyz": xyz,
        "offset": offset,
        "new_offset": new_offset,
    })

print(got[0].tolist())
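To make the variability easier to see, here is a minimal sketch (using the same xyz, offset and new_offset arrays as in the full script below) that re-exports and re-runs the model several times and compares each result against the eager output. The export may turn out to be deterministic within a single process, so it may be necessary to repeat this across separate script runs:

import numpy as np
from onnxruntime import InferenceSession

# Eager-mode reference result from the onnxscript function.
expected = np.asarray(furthest_sampling(xyz, offset, new_offset))

for trial in range(5):
    # Re-export to a fresh ModelProto on every iteration.
    model = furthest_sampling.to_model_proto()
    sess = InferenceSession(model.SerializeToString(),
                            providers=("CPUExecutionProvider",))
    got = sess.run(None, {"xyz": xyz, "offset": offset,
                          "new_offset": new_offset})[0]
    print(f"trial {trial}: match={np.array_equal(expected, got)}, got={got.tolist()}")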

Here is the full code:

from onnxscript import script
from onnxscript import opset17 as op
from onnx.helper import make_tensor
from onnxscript.onnx_types import FLOAT, INT32, INT64
import onnx
import onnxscript
import torch


custom_opset = onnxscript.values.Opset(domain="torch.onnx", version=1)


@script(custom_opset)
def furthest_sampling(xyz: FLOAT["N", 3], offset: INT32["N"], new_offset: INT32["N"]) -> INT64["N"]:
    points_num = op.Shape(xyz)[0:1] #int32
    npoint = op.Cast(new_offset[0:1], to=onnx.TensorProto.INT64) #int64
    N = op.Cast(offset[0:1], to=onnx.TensorProto.INT64) #int64
    #B = points_num / N
    B = op.Cast(op.Div(points_num, N), to=onnx.TensorProto.INT64) #int64
    batch_shape = op.Concat(B, N, op.Constant(
        value=make_tensor("con", onnx.TensorProto.INT64, [1], [3])
    ), axis=0)

    batch_xyz = op.Reshape(xyz, batch_shape) #float


    centroids_shape = op.Concat(B, npoint, axis=0) #int64
    centroids = op.ConstantOfShape(
        centroids_shape, 
        value=make_tensor("zero", onnx.TensorProto.INT64, [1], [0])) #int64

    distance_shape = op.Concat(B, N, axis=0) #int64
    distance = op.ConstantOfShape(
            distance_shape,  
            value= make_tensor("longdis", onnx.TensorProto.FLOAT, [1], [1e10])) #float

    batch_indices = op.CastLike(op.Range(start=0, limit=B, delta=1), distance_shape)
    batch_indices_2d = op.Unsqueeze(batch_indices, axes=[-1])


    barycenter = op.ReduceMean(batch_xyz, axes=[1], keepdims=1) #float
    dist = op.ReduceSum((batch_xyz - barycenter)**2, axes=[-1], keepdims=0) #float
    farthest = op.ArgMax(dist, axis=-1) #[B, 1] int64

    cond = op.Constant(value=make_tensor("true", onnx.TensorProto.BOOL, [], [1]))
    i = op.Constant(value=make_tensor("i", onnx.TensorProto.INT64, [1], [0]))
    col = op.ConstantOfShape(
            op.Shape(batch_indices_2d),
            value= [0])
    
    # centroids = op.Identity(centroids)
    # distance = op.Identity(distance)
    # i = op.Identity(i)
    # col = op.Identity(col)

    while cond:

        farthest_reshaped = op.Squeeze(farthest)
        update_indices = op.Concat(batch_indices_2d, col, axis=-1)
        centroids = op.ScatterND(centroids, update_indices, farthest_reshaped)
        #centroids[:, i] = farthest

        centroid_indice = op.Concat(batch_indices_2d, farthest, axis=-1)
        centroid = op.GatherND(batch_xyz, centroid_indice) #(B, 3)
        centroid_3d = op.Unsqueeze(centroid, axes=[1])

        dist = op.ReduceSum((batch_xyz - centroid_3d)**2, axes=[-1], keepdims=0)

        dist_A = op.CastLike(dist, distance)
        distance = op.Where(dist_A < distance, dist_A, distance)

        farthest = op.ArgMax(distance, axis=-1)

        i = op.Add(i, op.Constant(value_ints=[1]))
        cond = op.Squeeze(op.Less(i, npoint))
        col = op.Add(col, op.Constant(value_ints=[1]))


    jump = op.CastLike(op.Mul(N, batch_indices_2d), centroids)
    out = op.Add(centroids, jump)
    return op.Reshape(out, shape=op.Constant(value_ints=[-1]))

if __name__ == "__main__":
    import numpy as np
    np.random.seed(9527)

    npoint = 4
    xyz = np.random.randn(2, 8, 3).astype(np.float32)
    offset = np.arange(1, xyz.shape[0] + 1) * xyz.shape[1]
    new_offset = np.arange(1, xyz.shape[0] + 1) * npoint
    xyz = xyz.reshape(-1, 3)
    offset = offset.astype(np.int32)
    new_offset = new_offset.astype(np.int32)
    result_onnx = furthest_sampling(xyz, offset, new_offset)
    print(result_onnx)


    model = furthest_sampling.to_model_proto()
    model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model)
    onnx.save_model(model, "fps.onnx")


    from onnxruntime import InferenceSession
    sess = InferenceSession(model.SerializeToString(), providers=("CPUExecutionProvider",))

    got = sess.run(None, 
        {
        "xyz": xyz,
        "offset": offset, 
        "new_offset": new_offset
        })

    print(got[0].tolist())

I exported both a wrong and a correct ONNX file; if needed, I will post them with a link soon.
Any help is greatly appreciated.

@liguang-ops
Author

Sorry, I forgot to include my environment:

onnx                      1.17.0                   pypi_0    pypi
onnxruntime               1.19.2                   pypi_0    pypi
onnxscript                0.2.4                    pypi_0    pypi

@justinchuby
Collaborator

Can you show what you mean by "correct" and "wrong"?

@liguang-ops
Author

Can you show what you mean by "correct" and "wrong"?

Yes. The correct answer is:

result_onnx : [ 2  0  4  3 15 13 12 14]
result got: [2, 0, 4, 3, 15, 13, 12, 14]

The wrong answer is:

result_onnx : [ 2  0  4  3 15 13 12 14]
result got: [7, 11]

or

result_onnx: [ 2  0  4  3 15 13 12 14]
result got: [4, 12]

I should point out that all of these outputs were produced without changing the code at all: sometimes a run gives the correct result, sometimes a wrong one.
I suspect that the conversion to model_proto introduces a random error, though it may also be a problem with the code in my while loop.
I saved one correct model and one wrong model as ONNX files, converted both with the onnx2python.py script that ships with onnxscript, and found that the difference appears right after the loop ends.
Here is the diff between them:
[Image: diff of the onnx2python output for the correct and wrong models]
As you can see, the wrong model gets an incorrect assignment after the loop. I don't know whether this is caused by my non-standard code or by a bug in the library.
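For reference, a similar comparison can be made without onnx2python.py by printing both models in the ONNX textual format and diffing the output; the file names fps_correct.onnx and fps_wrong.onnx below are just placeholders for the two exports:

import difflib
import onnx

# Placeholder file names for one correct and one wrong export.
good = onnx.load("fps_correct.onnx")
bad = onnx.load("fps_wrong.onnx")

# onnx.printer.to_text renders a ModelProto in the ONNX textual format.
diff = difflib.unified_diff(
    onnx.printer.to_text(good).splitlines(),
    onnx.printer.to_text(bad).splitlines(),
    fromfile="fps_correct.onnx",
    tofile="fps_wrong.onnx",
    lineterm="",
)
print("\n".join(diff))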

@justinchuby
Collaborator

Thanks. We will look into this. cc @gramalingam (related to while-loop translation).

@justinchuby justinchuby added the help wanted Extra attention is needed label Apr 16, 2025