
[torch-ort-infer] Aten fallback doesn't work #139

Open
saipj opened this issue Aug 20, 2022 · 6 comments
saipj (Contributor) commented Aug 20, 2022

The ATen op doesn't fall back to the native PyTorch runtime as expected.

Versions:
Torch - 1.12.0
ONNX Runtime - 1.12.0
Torch-ort-infer - 1.12.0

Reproduction steps:

import torch
from torch_ort import ORTInferenceModule

def test_numpy_T(input_shape):
    class NeuralNet(torch.nn.Module):
        def __init__(self):
            super(NeuralNet, self).__init__()
        def forward(self, input):
            return input.T

    device = "cpu"
    ort_model = ORTInferenceModule(NeuralNet().to(device))

    def run_step(model, input):
        prediction = model(input)
        return prediction

    ort_input = torch.rand(input_shape, dtype=torch.float, device=device)
    ort_prediction = run_step(ort_model, ort_input)

if __name__ == "__main__":
    test_numpy_T([3, 2, 5])
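For context on what the repro exercises: for an N-dimensional tensor, `.T` reverses the order of all axes (a full transpose), which is the `numpy_T` ATen op the exporter trips over. A minimal sketch of that semantics, using NumPy rather than torch (not part of the original repro):

```python
# Illustration only: numpy's .T has the same axis-reversing semantics
# that torch.Tensor.T had for N-D tensors at the time of this issue.
import numpy as np

x = np.random.rand(3, 2, 5)          # same shape as the repro input
assert x.T.shape == (5, 2, 3)        # .T reverses all axes
assert np.array_equal(x.T, np.transpose(x, (2, 1, 0)))
```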

Error log

Traceback (most recent call last):
  File "unit_test_atenop.py", line 23, in <module>
    test_numpy_T([3, 2, 5])
  File "unit_test_atenop.py", line 20, in test_numpy_T
    ort_prediction = run_step(ort_model, ort_input)
  File "unit_test_atenop.py", line 16, in run_step
    prediction = model(input)
  File "/ort_aten_fb/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/ort_aten_fb/lib/python3.8/site-packages/torch_ort/ortinferencemodule/_utils_infer.py", line 98, in _forward
    return ortinferencemodule._forward_call(*inputs, **kwargs)
  File "/ort_aten_fb/lib/python3.8/site-packages/torch_ort/ortinferencemodule/ortinferencemodule.py", line 107, in _forward_call
    self._inference_session = onnxruntime.InferenceSession(
  File "/ort_aten_fb/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 347, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/ort_aten_fb/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 386, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (ATen_0) output arg (data) type inference failed.

Also tested with the symbolic shape inference call from ORTModule (ref: symbolic_shape); that fails with Exception("Incomplete symbolic shape inference").

natke (Collaborator) commented Aug 25, 2022

Can you please try with the latest PyTorch nightly?

saipj (Contributor, Author) commented Aug 31, 2022

With PyTorch nightly, ATen fallback no longer applies for this repro, since numpy_T is now exported directly as Transpose (pytorch/pytorch#79269).

natke (Collaborator) commented Sep 1, 2022

Thanks @saipj. Do you have another model we can test with?

natke (Collaborator) commented Sep 1, 2022

@saipj I experimented with the triu op that we talked about this morning and got it working. I needed a couple of workarounds: adding a custom symbolic, because ORT could not infer the shape and type info for this op, and loading the ATen cpp executor (the latter could be moved into ORTInferenceModule). Have a look at: https://github.com/natke/aten_snippet/blob/main/aten_model.py
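For readers landing here: the "custom symbolic" part of that workaround amounts to a function mapping aten::triu onto an ONNX op with well-defined type/shape info (Trilu, opset 14+), which would then be registered via torch.onnx.register_custom_op_symbolic. The sketch below is illustrative, not the linked snippet; the names triu_symbolic and MockGraph are mine, and a tiny mock stands in for the exporter's graph context so the mapping can be shown without a PyTorch install:

```python
# Hypothetical sketch of a custom symbolic for aten::triu.
# In a real export script this would be registered with:
#   torch.onnx.register_custom_op_symbolic("aten::triu", triu_symbolic, opset)

def triu_symbolic(g, self, diagonal):
    # ONNX Trilu with upper=1 keeps the upper triangle, matching torch.triu.
    # The "_i" suffix is the exporter's convention for int attributes.
    return g.op("Trilu", self, diagonal, upper_i=1)

class MockGraph:
    """Records g.op(...) calls the way the exporter's graph context would."""
    def __init__(self):
        self.nodes = []

    def op(self, name, *inputs, **attrs):
        self.nodes.append((name, inputs, attrs))
        return f"{name}_output"

g = MockGraph()
out = triu_symbolic(g, "input_tensor", "diagonal")
assert g.nodes == [("Trilu", ("input_tensor", "diagonal"), {"upper_i": 1})]
assert out == "Trilu_output"
```

Because the symbolic emits a standard ONNX op instead of an opaque ATen node, ORT's type inference succeeds and no ATen fallback is needed for this op.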

askhade (Contributor) commented Sep 12, 2022

@natke: Can you add this example to this repo, explaining how to set the type when type inference fails in ORT?

natke (Collaborator) commented Sep 12, 2022

@askhade Yes, sure
