
The calculation results of modelProto exported by onnxscript are variable #2203

Open
@liguang-ops

Description


Recently, I used onnxscript to implement the computation graph for farthest point sampling. When I call the decorated function directly (eager mode), the results are correct:

result_onnx = furthest_sampling(xyz, offset, new_offset)
print(result_onnx)

However, after converting it to a ModelProto, the onnxruntime results vary from run to run, sometimes correct and sometimes wrong:

model = furthest_sampling.to_model_proto()
model = onnx.shape_inference.infer_shapes(model)
onnx.checker.check_model(model)
onnx.save_model(model, "fps.onnx")

from onnxruntime import InferenceSession
sess = InferenceSession(model.SerializeToString(), providers=("CPUExecutionProvider",))

got = sess.run(None, {
    "xyz": xyz,
    "offset": offset,
    "new_offset": new_offset,
})

print(got[0].tolist())
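
To make the variability easier to see, I run the same session several times with identical inputs and compare the outputs (a minimal sketch, assuming sess, xyz, offset, and new_offset from the snippet above):

feeds = {"xyz": xyz, "offset": offset, "new_offset": new_offset}
# Run the exported model several times with identical inputs.
outputs = [sess.run(None, feeds)[0].tolist() for _ in range(5)]
for out in outputs:
    print(out)
# If the exported graph were deterministic, all five lists would be identical;
# in my case they sometimes differ.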

Here is the full code:

from onnxscript import script
from onnxscript import opset17 as op
from onnx.helper import make_tensor
from onnxscript.onnx_types import FLOAT, INT32, INT64
import onnx
import onnxscript
import torch


custom_opset = onnxscript.values.Opset(domain="torch.onnx", version=1)


@script(custom_opset)
def furthest_sampling(xyz: FLOAT["N", 3], offset: INT32["N"], new_offset: INT32["N"]) -> INT64["N"]:
    points_num = op.Shape(xyz)[0:1]  # int64: total number of points
    npoint = op.Cast(new_offset[0:1], to=onnx.TensorProto.INT64)  # int64: samples per batch
    N = op.Cast(offset[0:1], to=onnx.TensorProto.INT64)  # int64: points per batch
    # B = points_num / N
    B = op.Cast(op.Div(points_num, N), to=onnx.TensorProto.INT64)  # int64: number of batches
    batch_shape = op.Concat(B, N, op.Constant(
        value=make_tensor("con", onnx.TensorProto.INT64, [1], [3])
    ), axis=0)

    batch_xyz = op.Reshape(xyz, batch_shape) #float


    centroids_shape = op.Concat(B, npoint, axis=0) #int64
    centroids = op.ConstantOfShape(
        centroids_shape, 
        value=make_tensor("zero", onnx.TensorProto.INT64, [1], [0])) #int64

    distance_shape = op.Concat(B, N, axis=0) #int64
    distance = op.ConstantOfShape(
            distance_shape,  
            value=make_tensor("longdis", onnx.TensorProto.FLOAT, [1], [1e10]))  # float

    batch_indices = op.CastLike(op.Range(start=0, limit=B, delta=1), distance_shape)
    batch_indices_2d = op.Unsqueeze(batch_indices, axes=[-1])


    barycenter = op.ReduceMean(batch_xyz, axes=[1], keepdims=1) #float
    dist = op.ReduceSum((batch_xyz - barycenter)**2, axes=[-1], keepdims=0) #float
    farthest = op.ArgMax(dist, axis=-1) #[B, 1] int64

    cond = op.Constant(value=make_tensor("true", onnx.TensorProto.BOOL, [], [1]))
    i = op.Constant(value=make_tensor("i", onnx.TensorProto.INT64, [1], [0]))
    col = op.ConstantOfShape(
        op.Shape(batch_indices_2d),
        value=[0])  # column index i used in the ScatterND updates

    # centroids = op.Identity(centroids)
    # distance = op.Identity(distance)
    # i = op.Identity(i)
    # col = op.Identity(col)

    while cond:

        farthest_reshaped = op.Squeeze(farthest)
        update_indices = op.Concat(batch_indices_2d, col, axis=-1)
        centroids = op.ScatterND(centroids, update_indices, farthest_reshaped)
        #centroids[:, i] = farthest

        centroid_indice = op.Concat(batch_indices_2d, farthest, axis=-1)
        centroid = op.GatherND(batch_xyz, centroid_indice) #(B, 3)
        centroid_3d = op.Unsqueeze(centroid, axes=[1])

        dist = op.ReduceSum((batch_xyz - centroid_3d)**2, axes=[-1], keepdims=0)

        dist_A = op.CastLike(dist, distance)
        distance = op.Where(dist_A < distance, dist_A, distance)

        farthest = op.ArgMax(distance, axis=-1)

        i = op.Add(i, op.Constant(value_ints=[1]))
        cond = op.Squeeze(op.Less(i, npoint))
        col = op.Add(col, op.Constant(value_ints=[1]))


    jump = op.CastLike(op.Mul(N, batch_indices_2d), centroids)
    out = op.Add(centroids, jump)
    return op.Reshape(out, shape=op.Constant(value_ints=[-1]))

if __name__ == "__main__":
    import numpy as np
    np.random.seed(9527)

    npoint = 4
    xyz = np.random.randn(2, 8, 3).astype(np.float32)
    offset = np.arange(1, xyz.shape[0] + 1) * xyz.shape[1]
    new_offset = np.arange(1, xyz.shape[0] + 1) * npoint
    xyz = xyz.reshape(-1, 3)
    offset = offset.astype(np.int32)
    new_offset = new_offset.astype(np.int32)
    result_onnx = furthest_sampling(xyz, offset, new_offset)
    print(result_onnx)


    model = furthest_sampling.to_model_proto()
    model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model)
    onnx.save_model(model, "fps.onnx")


    from onnxruntime import InferenceSession
    sess = InferenceSession(model.SerializeToString(), providers=("CPUExecutionProvider",))

    got = sess.run(None, {
        "xyz": xyz,
        "offset": offset,
        "new_offset": new_offset,
    })

    print(got[0].tolist())
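
For cross-checking, I also compare against a plain NumPy re-implementation of the same batched farthest point sampling (a minimal sketch of what I expect the graph to compute, assuming every batch has the same number of points; the name furthest_sampling_np is just for illustration):

import numpy as np

def furthest_sampling_np(xyz, offset, new_offset):
    # NumPy reference of the batched FPS above; assumes every batch has
    # the same number of points, like the ONNX graph does.
    n_per_batch = int(offset[0])
    npoint = int(new_offset[0])
    batch_xyz = xyz.reshape(-1, n_per_batch, 3)              # (B, N, 3)
    B, N, _ = batch_xyz.shape

    centroids = np.zeros((B, npoint), dtype=np.int64)
    distance = np.full((B, N), 1e10, dtype=np.float32)

    # Start from the point farthest away from each batch's barycenter.
    barycenter = batch_xyz.mean(axis=1, keepdims=True)       # (B, 1, 3)
    farthest = ((batch_xyz - barycenter) ** 2).sum(-1).argmax(-1)  # (B,)

    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = batch_xyz[np.arange(B), farthest][:, None, :]   # (B, 1, 3)
        dist = ((batch_xyz - centroid) ** 2).sum(-1)
        distance = np.minimum(distance, dist)
        farthest = distance.argmax(-1)

    # Turn per-batch indices into indices into the flattened xyz, as the graph does.
    return (centroids + N * np.arange(B)[:, None]).reshape(-1)

I compare both the eager-mode result and got[0] against this reference.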

I exported both a wrong and a correct ONNX file; if needed, I will post them with a link soon.
Any help is greatly appreciated.
