Recently, I used onnxscript to implement the computation graph for farthest point sampling. When I tested the function directly (eager mode), the results were correct:
```python
result_onnx = furthest_sampling(xyz, offset, new_offset)
print(result_onnx)
```
But after converting it to a ModelProto and running it with onnxruntime, the results change randomly from run to run, sometimes correct and sometimes wrong:
```python
model = furthest_sampling.to_model_proto()
model = onnx.shape_inference.infer_shapes(model)
onnx.checker.check_model(model)
onnx.save_model(model, "fps.onnx")

from onnxruntime import InferenceSession

sess = InferenceSession(model.SerializeToString(), providers=("CPUExecutionProvider",))
got = sess.run(None, {
    "xyz": xyz,
    "offset": offset,
    "new_offset": new_offset,
})
print(got[0].tolist())
```
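To make the flakiness concrete, one can re-run the same session on identical inputs and compare the outputs; for a deterministic graph every run should match exactly. A minimal sketch (the check loop below is illustrative, not part of the original script):

```python
# Illustrative sanity check (not in the original script): re-run the
# same InferenceSession on the same feeds and flag any run whose
# output differs from the first one.
feeds = {"xyz": xyz, "offset": offset, "new_offset": new_offset}
reference = sess.run(None, feeds)[0]
for run in range(20):
    out = sess.run(None, feeds)[0]
    if not (out == reference).all():
        print(f"run {run}: output differs from the first run")
```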
Here is the full code:
```python
from onnxscript import script
from onnxscript import opset17 as op
from onnx.helper import make_tensor
from onnxscript.onnx_types import FLOAT, INT32, INT64
import onnx
import onnxscript
import torch

custom_opset = onnxscript.values.Opset(domain="torch.onnx", version=1)


@script(custom_opset)
def furthest_sampling(xyz: FLOAT["N", 3], offset: INT32["N"], new_offset: INT32["N"]) -> INT64["N"]:
    points_num = op.Shape(xyz)[0:1]  # int64
    npoint = op.Cast(new_offset[0:1], to=onnx.TensorProto.INT64)  # int64
    N = op.Cast(offset[0:1], to=onnx.TensorProto.INT64)  # int64
    # B = points_num / N
    B = op.Cast(op.Div(points_num, N), to=onnx.TensorProto.INT64)  # int64
    batch_shape = op.Concat(B, N, op.Constant(
        value=make_tensor("con", onnx.TensorProto.INT64, [1], [3])
    ), axis=0)
    batch_xyz = op.Reshape(xyz, batch_shape)  # float
    centroids_shape = op.Concat(B, npoint, axis=0)  # int64
    centroids = op.ConstantOfShape(
        centroids_shape,
        value=make_tensor("zero", onnx.TensorProto.INT64, [1], [0]))  # int64
    distance_shape = op.Concat(B, N, axis=0)  # int64
    distance = op.ConstantOfShape(
        distance_shape,
        value=make_tensor("longdis", onnx.TensorProto.FLOAT, [1], [1e10]))  # float
    batch_indices = op.CastLike(op.Range(start=0, limit=B, delta=1), distance_shape)
    batch_indices_2d = op.Unsqueeze(batch_indices, axes=[-1])
    barycenter = op.ReduceMean(batch_xyz, axes=[1], keepdims=1)  # float
    dist = op.ReduceSum((batch_xyz - barycenter) ** 2, axes=[-1], keepdims=0)  # float
    farthest = op.ArgMax(dist, axis=-1)  # [B, 1] int64
    cond = op.Constant(value=make_tensor("true", onnx.TensorProto.BOOL, [], [1]))
    i = op.Constant(value=make_tensor("i", onnx.TensorProto.INT64, [1], [0]))
    col = op.ConstantOfShape(
        op.Shape(batch_indices_2d),
        value=[0])
    # centroids = op.Identity(centroids)
    # distance = op.Identity(distance)
    # i = op.Identity(i)
    # col = op.Identity(col)
    while cond:
        farthest_reshaped = op.Squeeze(farthest)
        update_indices = op.Concat(batch_indices_2d, col, axis=-1)
        centroids = op.ScatterND(centroids, update_indices, farthest_reshaped)
        # centroids[:, i] = farthest
        centroid_indice = op.Concat(batch_indices_2d, farthest, axis=-1)
        centroid = op.GatherND(batch_xyz, centroid_indice)  # (B, 3)
        centroid_3d = op.Unsqueeze(centroid, axes=[1])
        dist = op.ReduceSum((batch_xyz - centroid_3d) ** 2, axes=[-1], keepdims=0)
        dist_A = op.CastLike(dist, distance)
        distance = op.Where(dist_A < distance, dist_A, distance)
        farthest = op.ArgMax(distance, axis=-1)
        i = op.Add(i, op.Constant(value_ints=[1]))
        cond = op.Squeeze(op.Less(i, npoint))
        col = op.Add(col, op.Constant(value_ints=[1]))
    jump = op.CastLike(op.Mul(N, batch_indices_2d), centroids)
    out = op.Add(centroids, jump)
    return op.Reshape(out, shape=op.Constant(value_ints=[-1]))


if __name__ == "__main__":
    import numpy as np

    np.random.seed(9527)
    npoint = 4
    xyz = np.random.randn(2, 8, 3).astype(np.float32)
    offset = np.arange(1, xyz.shape[0] + 1) * xyz.shape[1]
    new_offset = np.arange(1, xyz.shape[0] + 1) * npoint
    xyz = xyz.reshape(-1, 3)
    offset = offset.astype(np.int32)
    new_offset = new_offset.astype(np.int32)

    # Eager-mode evaluation: this produces the correct result.
    result_onnx = furthest_sampling(xyz, offset, new_offset)
    print(result_onnx)

    # Export to a ModelProto and run with onnxruntime:
    # this produces randomly varying results.
    model = furthest_sampling.to_model_proto()
    model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model)
    onnx.save_model(model, "fps.onnx")

    from onnxruntime import InferenceSession

    sess = InferenceSession(model.SerializeToString(), providers=("CPUExecutionProvider",))
    got = sess.run(None, {
        "xyz": xyz,
        "offset": offset,
        "new_offset": new_offset,
    })
    print(got[0].tolist())
```
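For comparison, here is a plain NumPy farthest point sampling that can serve as a ground-truth oracle. This is my sketch, not part of the original script; it assumes the same conventions as the graph above (equally sized batches, first centroid farthest from the barycenter, squared-Euclidean distances, flat output indices):

```python
import numpy as np

def fps_numpy(xyz_flat, offset, new_offset):
    """Reference FPS oracle (illustrative sketch, not from the issue).

    Mirrors the ONNX graph above: the flat (B*N, 3) cloud is reshaped
    to (B, N, 3), the first centroid is the point farthest from the
    barycenter, and flat indices into xyz_flat are returned.
    """
    N = int(offset[0])
    B = xyz_flat.shape[0] // N
    npoint = int(new_offset[0])
    pts = xyz_flat.reshape(B, N, 3)
    result = np.zeros((B, npoint), dtype=np.int64)
    distance = np.full((B, N), 1e10, dtype=np.float32)
    barycenter = pts.mean(axis=1, keepdims=True)
    farthest = ((pts - barycenter) ** 2).sum(-1).argmax(-1)  # (B,)
    for i in range(npoint):
        result[:, i] = farthest
        centroid = pts[np.arange(B), farthest][:, None, :]  # (B, 1, 3)
        dist = ((pts - centroid) ** 2).sum(-1)
        distance = np.minimum(distance, dist)
        farthest = distance.argmax(-1)
    return (result + N * np.arange(B)[:, None]).reshape(-1)

# Using the arrays from the __main__ block above:
print(fps_numpy(xyz, offset, new_offset).tolist())
```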
I exported both a wrong and a correct ONNX file; if needed, I will post them via a link soon.
Any help is greatly appreciated.