Found a bug for cases where the graph input is a 0-dim tensor.
error[E0308]: mismatched types
    --> crates/burn-import/onnx-tests/tests/onnx_tests.rs:1071:43
     |
1071 |     let output = model.forward(input, 1.0);
     |                        ------- ^^^ expected `Tensor<NdArray, 0>`, found floating-point number
     |                        |
     |                        arguments to this method are incorrect
     |
     = note: expected struct `Tensor<NdArray, 0>`
                found type `{float}`
note: method defined here
    --> /Users/dilshod/Projects/burn/target/debug/build/onnx-tests-9ba10161c51b05aa/out/model/unsqueeze_opset16.rs:40:12
     |
40   |     pub fn forward(
     |            ^^^^^^^
...
43   |         input2: Tensor<B, 0>,
     |         --------------------

For more information about this error, try `rustc --explain E0308`.
error: could not compile `onnx-tests` (test "onnx_tests") due to 1 previous error
#!/usr/bin/env python3

# used to generate model: unsqueeze.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.axis = 3

    def forward(self, x, scalar):
        x = torch.unsqueeze(x, self.axis)
        y = torch.unsqueeze(scalar, 0)
        return x, y


def main():
    # Set seed for reproducibility
    torch.manual_seed(42)
    torch.set_printoptions(precision=8)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    test_input = (torch.randn(3, 4, 5, device=device), torch.tensor(1.0, device=device))
    model = Model()
    output = model.forward(*test_input)

    torch.onnx.export(model, test_input, "unsqueeze_opset16.onnx", verbose=False, opset_version=16)
    torch.onnx.export(model, test_input, "unsqueeze_opset11.onnx", verbose=False, opset_version=11)

    print(f"Finished exporting model")

    # Output some test data for use in the test
    print(f"Test input data of ones: {test_input}")
    print(f"Test input data shape of ones: {test_input[0].shape}")
    # output = model.forward(test_input)
    print(f"Test output data shape: {output[0].shape}")
    print(f"Test output: {output}")


if __name__ == "__main__":
    main()
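For reference, here is a small sketch (not part of the original script) that inspects the exported graph with the onnx Python package; the scalar input should show up with an empty dim list, i.e. rank 0, which is what burn-import then maps to `Tensor<B, 0>`:

# Hedged sanity check, assuming the onnx package is installed and the file name
# matches the export above: a scalar graph input is reported with an empty dim list.
import onnx

model = onnx.load("unsqueeze_opset16.onnx")
for inp in model.graph.input:
    rank = len(inp.type.tensor_type.shape.dim)
    print(inp.name, "rank =", rank)  # the scalar input prints rank = 0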
Yeah, I've been stumbling into issues with scalars recently while looking at different ops to add support for...
Even for node inputs/outputs this might cause trouble, since the ONNX spec allows scalar (rank-0) tensors but Burn doesn't currently support them. We should probably discuss this internally to see how we want to handle it, because it will be an obvious obstacle to supporting as many ONNX ops (and thus models) as possible.
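For what it's worth, ONNX runtimes themselves accept rank-0 inputs, which is what makes this an importer-side gap. A minimal onnxruntime sketch against the model exported above (the feed names are looked up from the session rather than assumed):

# Hedged sketch, assuming onnxruntime is installed and the exported file exists;
# the scalar is passed as a 0-dim numpy array, i.e. shape ().
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("unsqueeze_opset16.onnx")
x = np.random.randn(3, 4, 5).astype(np.float32)
scalar = np.array(1.0, dtype=np.float32)  # shape () -> a rank-0 tensor
feeds = dict(zip((i.name for i in sess.get_inputs()), (x, scalar)))
outputs = sess.run(None, feeds)
print([o.shape for o in outputs])  # expected roughly [(3, 4, 5, 1), (1,)]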