[associative_scan] Lifted arguments #140043
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140043
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 1 Unrelated Failure as of commit 9cbf5fb with merge base d100e9a.
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
WIP: export tests
with discard_graph_changes(tx):
    # See NOTE [unspecialize int carry with unbacked symints]
    # Note: this must be run under discard graph changes.
    def create_unbacked_sym_node_var(tx) -> SymNodeVariable:
We don't want to do auto unspecialization for all control flow yet. Additional inputs are read-only, so we don't need to create new unbacked symbols.
I reverted the handling back to the original implementation using cloning. In addition, there is now a check that all additional_inputs are TensorVariables.
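For illustration, a minimal sketch of what such a check could look like (the helper name and error message are hypothetical, not the actual dynamo code in this PR):

```python
from torch._dynamo.variables.tensor import TensorVariable

def _check_additional_inputs(additional_inputs):
    # Only lifted tensor arguments are supported for associative_scan here,
    # so reject anything that is not a TensorVariable.
    for inp in additional_inputs:
        if not isinstance(inp, TensorVariable):
            raise RuntimeError(
                "associative_scan: expected all additional_inputs to be "
                f"TensorVariables, got {type(inp).__name__}"
            )
```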
test/export/test_export.py
Outdated
    def forward(self, x):
        return associative_scan(self.combine_fn, x, 1)

ep = export(Foo(), (xs,), dynamic_shapes={"x": {1: dim0}})
If we make dim 1 dynamic, the input in this case has shape (3, s0, 2). Since we then associative_scan over dim 1, the subgraph will work on a statically shaped input (3, 2). If we want to make the subgraph dynamic, we'll need to scan over the 0th or 2nd dim, or make the first dim dynamic instead; in that case, I'm expecting the symbol s0 to be lifted as additional_inputs.
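To make the shape reasoning concrete, a sketch of the two variants (the `Foo` module is assumed from the test, which is not shown in full above):

```python
import torch
from torch.export import Dim, export

xs = torch.randn(3, 10, 2)
dim0 = Dim("dim0")

# Variant discussed above: dim 1 is dynamic, i.e. xs is (3, s0, 2), but the
# scan is also over dim 1, so combine_fn only ever sees static (3, 2) slices.
ep = export(Foo(), (xs,), dynamic_shapes={"x": {1: dim0}})

# To exercise dynamic shapes *inside* the subgraph, make dim 0 dynamic
# instead while still scanning over dim 1: slices are then (s0, 2).
ep = export(Foo(), (xs,), dynamic_shapes={"x": {0: dim0}})
```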
As we discussed offline, during the export tests the symbolic shapes are not lifted, and thus we don't see the int in the additional_inputs.
test/export/test_export.py
Outdated
        return associative_scan(self.combine_fn, x, 1, combine_mode="generic")

inp = torch.randn(3, 10, 2, device=torch.device("cuda"))
ep = torch.export.export(M(), (inp,))
We can replace `torch.export.export` with `export`; there's some test patching behind the scenes when we use `export`.
The failure is probably caused by the randn call for the buffer used inside vmap (not sure why the randn is called inside vmap rather than outside at module initialization time). We can put a constant tensor in the buffer.
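A sketch of the suggested fix, assuming the `M` module from the test (the buffer shape is illustrative, not taken from the test source):

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A constant tensor instead of torch.randn, so no random op ends up
        # being traced when combine_fn runs under vmap.
        self.register_buffer("buffer", torch.ones(3, 2))

    def combine_fn(self, x, y):
        return (x + y) * self.buffer
```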
I changed the `torch.randn` to `torch.ones` and the `torch.export.export` to `export`. However, these export tests are giving me quite a headache. The logs of the three tests are attached below.
The last two tests fail with a vmap issue. Is this because associative_scan does not yet have a vmap implementation?
ERROR: test_export_associative_scan_lifted_buffers (__main__.TestExport.test_export_associative_scan_lifted_buffers)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3120, in wrapper
method(*args, **kwargs)
File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6346, in test_export_associative_scan_lifted_buffers
ep = export(M(), (inp,))
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/__init__.py", line 368, in export
return _export(
^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
raise e
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 2079, in _export
return _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
raise e
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1944, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1384, in _strict_export_lower_to_aten_ir
aten_export_artifact = lower_to_aten_callback(
^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1665, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1585, in _make_fx_helper
gm = make_fx(
^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 2194, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 2132, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 2103, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 749, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1136, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1692, in trace
res = super().trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 749, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/_symbolic_trace.py", line 832, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1191, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "<string>", line 1, in <lambda>
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1488, in wrapped_fn
return tuple(flat_fn(*args))
^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
out = PropagateUnbackedSymInts(mod).run(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/interpreter.py", line 171, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6944, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/interpreter.py", line 234, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/interpreter.py", line 314, in call_function
return target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1239, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1286, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_ops.py", line 866, in handler
return torch._library.utils.handle_dispatch_mode(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_library/utils.py", line 296, in handle_dispatch_mode
return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1341, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 973, in proxy_call
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 674, in track_tensor_tree
wrap_with_proxy(inner_res, proxy_res, constant)
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 621, in wrap_with_proxy
set_meta(proxy, e)
File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 491, in set_meta
proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/fx/passes/shape_prop.py", line 55, in _extract_tensor_metadata
if result.is_contiguous(memory_format=query_format):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format
While executing %add_1 : [num_users=1] = call_function[target=operator.add](args = (%_add_batch_dim, %_add_batch_dim_1), kwargs = {})
Original traceback:
File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6343, in forward
return associative_scan(self.combine_fn, x, 1, combine_mode="generic")
File "/data_malta3_ssd/pytorch/torch/_higher_order_ops/associative_scan.py", line 229, in associative_scan
result_flat = generic_associative_scan(combine_fn, leaves, additional_inputs=())
File "/data_malta3_ssd/pytorch/torch/_higher_order_ops/associative_scan.py", line 342, in generic_associative_scan
scans = _scan(leaves)
File "/data_malta3_ssd/pytorch/torch/_higher_order_ops/associative_scan.py", line 303, in _scan
reduced_elems = operator(
File "/data_malta3_ssd/pytorch/torch/_higher_order_ops/associative_scan.py", line 36, in wrap_combine_fn_flat
combined = combine_fn(lhs, rhs)
File "/data_malta3_ssd/pytorch/torch/_functorch/apis.py", line 202, in wrapped
return vmap_impl(
File "/data_malta3_ssd/pytorch/torch/_functorch/vmap.py", line 331, in vmap_impl
return _flat_vmap(
File "/data_malta3_ssd/pytorch/torch/_functorch/vmap.py", line 481, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6340, in combine_fn
return (x + y) * self.buffer
To execute this test, run the following from the base repo dir:
python test/export/test_export.py TestExport.test_export_associative_scan_lifted_buffers
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
======================================================================
ERROR: test_export_associative_scan_symbol_dim (__main__.TestExport.test_export_associative_scan_symbol_dim)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3120, in wrapper
method(*args, **kwargs)
File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6281, in test_export_associative_scan_symbol_dim
ep = export(Foo(), (xs,), dynamic_shapes={"x": {1: dim1}})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/__init__.py", line 368, in export
return _export(
^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
raise e
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 2079, in _export
return _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
raise e
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1944, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1296, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 693, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 1579, in inner
result_traced = opt_f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 570, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 1400, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 565, in __call__
return _compile(
^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 997, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 726, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 760, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/bytecode_transformation.py", line 1404, in transform_code_object
transformations(instructions, code_options)
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 236, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 660, in transform
tracer = InstructionTranslator(
^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/symbolic_convert.py", line 2780, in __init__
self._throw_if_in_functorch()
File "/data_malta3_ssd/pytorch/torch/_dynamo/symbolic_convert.py", line 2896, in _throw_if_in_functorch
unimplemented(msg)
File "/data_malta3_ssd/pytorch/torch/_dynamo/exc.py", line 380, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: If you are reaching here, it means dynamo failed for one of the following reasons:
- Calling torch.func.vmap(compiled_fn) function from eager mode is not supported. Ensure that torch.func.vmap is also wrapped within a torch.compile function. For more information, see PyTorch issue #128711.
- torch.func.vmap(fn) requires the function to be inlined by dynamo
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
To execute this test, run the following from the base repo dir:
python test/export/test_export.py TestExport.test_export_associative_scan_symbol_dim
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
======================================================================
ERROR: test_export_associative_scan_symbol_scandim (__main__.TestExport.test_export_associative_scan_symbol_scandim)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3120, in wrapper
method(*args, **kwargs)
File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6313, in test_export_associative_scan_symbol_scandim
ep = export(Foo(), (xs,), dynamic_shapes={"x": {1: dim1}})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/__init__.py", line 368, in export
return _export(
^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
raise e
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 2079, in _export
return _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
raise e
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1944, in _export_for_training
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1296, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 693, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 1579, in inner
result_traced = opt_f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 570, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 1400, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 565, in __call__
return _compile(
^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 997, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 726, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 760, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/bytecode_transformation.py", line 1404, in transform_code_object
transformations(instructions, code_options)
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 236, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 660, in transform
tracer = InstructionTranslator(
^^^^^^^^^^^^^^^^^^^^^^
File "/data_malta3_ssd/pytorch/torch/_dynamo/symbolic_convert.py", line 2780, in __init__
self._throw_if_in_functorch()
File "/data_malta3_ssd/pytorch/torch/_dynamo/symbolic_convert.py", line 2896, in _throw_if_in_functorch
unimplemented(msg)
File "/data_malta3_ssd/pytorch/torch/_dynamo/exc.py", line 380, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: If you are reaching here, it means dynamo failed for one of the following reasons:
- Calling torch.func.vmap(compiled_fn) function from eager mode is not supported. Ensure that torch.func.vmap is also wrapped within a torch.compile function. For more information, see PyTorch issue #128711.
- torch.func.vmap(fn) requires the function to be inlined by dynamo
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
To execute this test, run the following from the base repo dir:
python test/export/test_export.py TestExport.test_export_associative_scan_symbol_scandim
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 3 tests in 2.202s
FAILED (errors=3)
For the first failure, it looks like an export x vmap problem: we probably have to use the pointwise mode and fix the export issues. Since export doesn't require inductor compilation, I'm expecting export to succeed.
For the second failure, I'm actually not sure why we'd trigger this error, because the exported model is not under vmap. Does running `python test/export/test_export.py TestExport.test_export_associative_scan_symbol_scandim` alone also fail?
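For reference, the pointwise variant being suggested would change the scan call in the test to the following (same test module assumed):

```python
return associative_scan(self.combine_fn, x, 1, combine_mode="pointwise")
```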
Using `pointwise` for the first test case doesn't work. I think inductor is involved, and in the current lowering, `additional_inputs` are not supported in inductor:
LoweringException: RuntimeError: Unable to generate code for associative_scan op, because there are lifted arguments
For the second issue, `python test/export/test_export.py TestExport.test_export_associative_scan_symbol_scandim` does pass on its own. In fact, both of the latter tests pass when run individually, i.e.
python test/export/test_export.py TestExport.test_export_associative_scan_symbol_dim
and
python test/export/test_export.py TestExport.test_export_associative_scan_symbol_scandim
both pass.
We need to update the associative_scan's backend to "eager": https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/associative_scan.py#L137
Edit: we can make this change in the next PR. We can mark this test as expected-failure for now and remove the expected-failure in the next PR.
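A sketch of the suggested backend change at the linked line (the exact surrounding code and signature are assumed from the link, not reproduced verbatim):

```python
# Inside torch/_higher_order_ops/associative_scan.py: compile the scan entry
# point with the eager backend, so tracing does not pull in inductor.
return torch.compile(associative_scan, backend="eager", fullgraph=True)(
    combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode
)
```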
Right, I've made the changes except for the backend. Indeed, the backend change does resolve the remaining issues.
However, the problem with switching the backend to "eager" is that the pointwise check needs to be reworked: the pointwise check is invoked only in the tracing function for inductor, and after switching the backend, this check is not invoked anymore.
Could you please start the CI tests?
The backend change is pursued in a separate PR here #146973
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed
Reason: 2 jobs have failed; the first few of them are: linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test
Details for Dev Infra team: Raised by workflow job.
@pytorchbot merge -i
Merge started
Your change will be merged while ignoring the following 3 checks: linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test, linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR fixes some issues with torch export discussed here: #140043 (comment). However, this backend change still does not resolve the failure for specific shapes mentioned here: #137943 (comment). Pull Request resolved: #146973. Approved by: https://github.com/ydwu4
This PR implements lifted arguments for associative_scan.
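For illustration, a minimal sketch of what a lifted argument means here: a combine_fn that closes over an outer tensor, which tracing lifts into the scan subgraph as an additional input. The shapes and the multiplicative combine are illustrative only, and the runtime constraints discussed in the thread (e.g. combine_mode, device) still apply:

```python
import torch
from torch._higher_order_ops.associative_scan import associative_scan

weight = torch.ones(2)  # captured from the enclosing scope

def combine_fn(x, y):
    # `weight` is not a scan operand; tracing lifts it into the subgraph as
    # an additional input. Note x * y * weight is associative, as required.
    return x * y * weight

xs = torch.randn(4, 2)
out = associative_scan(combine_fn, xs, dim=0, combine_mode="generic")
```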
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @ydwu4