Skip to content

Commit

Permalink
[dynamo] Explicitly fall back to eager with GraphModule with no output for onnx & tvm backends (#99805)
Browse files Browse the repository at this point in the history

Fixes #99437

Pull Request resolved: #99805
Approved by: https://github.com/jansel
  • Loading branch information
YJ Shi authored and pytorchmergebot committed Apr 23, 2023
1 parent 9b0b31a commit 72daade
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
3 changes: 3 additions & 0 deletions torch/_dynamo/backends/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def onnxrt(gm, example_inputs, *, filename=None, provider=None):

device_type = device_from_inputs(example_inputs).type
example_outputs = gm(*example_inputs)
if len(example_outputs) == 0:
log.warning("Explicitly fall back to eager due to zero output")
return gm.forward
output_spec = [
(o.shape, o.dtype, o.layout, o.device, o.requires_grad) for o in example_outputs
]
Expand Down
4 changes: 4 additions & 0 deletions torch/_dynamo/backends/tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def tvm(gm, example_inputs, *, scheduler=None, trials=20000):
jit_mod = torch.jit.trace(gm, example_inputs)
device = device_from_inputs(example_inputs)
shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
example_outputs = gm(*example_inputs)
if len(example_outputs) == 0:
log.warning("Explicitly fall back to eager due to zero output")
return gm.forward
mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
if device.type == "cuda":
dev = tvm.cuda(device.index)
Expand Down

0 comments on commit 72daade

Please sign in to comment.