Skip to content

Conversation

bohnstingl
Copy link
Collaborator

@bohnstingl bohnstingl commented Nov 7, 2024

@bohnstingl bohnstingl requested a review from zou3519 as a code owner November 7, 2024 20:21
Copy link

pytorch-bot bot commented Nov 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140043

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit 9cbf5fb with merge base d100e9a (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bohnstingl
Copy link
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Nov 7, 2024
@bdhirsh bdhirsh added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Nov 8, 2024
@zou3519 zou3519 requested review from ydwu4 and removed request for zou3519 November 11, 2024 17:59
@bohnstingl bohnstingl changed the title Improvements for associative_scan - Lifted arguments [associative_scan] Lifted arguments Nov 19, 2024
with discard_graph_changes(tx):
# See NOTE [unspecialize int carry with unbacked symints]
# Note: this must be run under discard graph changes.
def create_unbacked_sym_node_var(tx) -> SymNodeVariable:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to do auto unspecialize for all control flow yet. Additional inputs is read-only so we don't need to create new unbacked symbols.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reverted the handling back to the original implementation using cloning. In addition, there is now a check that all additional_inputs are TensorVariables.

def forward(self, x):
return associative_scan(self.combine_fn, x, 1)

ep = export(Foo(), (xs,), dynamic_shapes={"x": {1: dim0}})
Copy link
Contributor

@ydwu4 ydwu4 Jan 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we make dim 1 dynamic in this case (3, s0, 2), we then associative_scan over dim 1, the subgraph will work on a static shaped input (3, 2). If we want to make the subgraph dynamic, we'll need to scan over 0-th or 2-th dim or we change to first dim being dynamic. and in that case, i'm expecting symbol s0 to be lifted as additional_inputs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed offline, during the export tests the symbolic shapes are not lifted and thus we don't see the int in the additional_inputs.

return associative_scan(self.combine_fn, x, 1, combine_mode="generic")

inp = torch.randn(3, 10, 2, device=torch.device("cuda"))
ep = torch.export.export(M(), (inp,))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can replace "torch.export.export" with "export", there're some test patching behind the scene if we use export.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The failure is probably caused by the randn call of the buffer (not sure why the randn is called inside vmap rather than outside at module initialization time) used inside vmap. We can put a constant tensor in the buffer.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the torch.randn to torch.ones and the torch.export.export to export. However, these export tests are giving me quite some headache. The logs of the three tests are attached below:

The two last tests fail with a vmap issue. Is this because the associative_scan does not yet have a vmap implementation?

ERROR: test_export_associative_scan_lifted_buffers (__main__.TestExport.test_export_associative_scan_lifted_buffers)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3120, in wrapper
    method(*args, **kwargs)
  File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6346, in test_export_associative_scan_lifted_buffers
    ep = export(M(), (inp,))
         ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/__init__.py", line 368, in export
    return _export(
           ^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 2079, in _export
    return _export_for_training(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1944, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1384, in _strict_export_lower_to_aten_ir
    aten_export_artifact = lower_to_aten_callback(
                           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1665, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1585, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 2194, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 2132, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 2103, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 749, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1136, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1692, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 749, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/_symbolic_trace.py", line 832, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1191, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1488, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6944, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/interpreter.py", line 234, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/interpreter.py", line 314, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1239, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1286, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_ops.py", line 866, in handler
    return torch._library.utils.handle_dispatch_mode(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_library/utils.py", line 296, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 1341, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 973, in proxy_call
    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 674, in track_tensor_tree
    wrap_with_proxy(inner_res, proxy_res, constant)
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 621, in wrap_with_proxy
    set_meta(proxy, e)
  File "/data_malta3_ssd/pytorch/torch/fx/experimental/proxy_tensor.py", line 491, in set_meta
    proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val)
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/fx/passes/shape_prop.py", line 55, in _extract_tensor_metadata
    if result.is_contiguous(memory_format=query_format):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format

While executing %add_1 : [num_users=1] = call_function[target=operator.add](args = (%_add_batch_dim, %_add_batch_dim_1), kwargs = {})
Original traceback:
  File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6343, in forward
    return associative_scan(self.combine_fn, x, 1, combine_mode="generic")
  File "/data_malta3_ssd/pytorch/torch/_higher_order_ops/associative_scan.py", line 229, in associative_scan
    result_flat = generic_associative_scan(combine_fn, leaves, additional_inputs=())
  File "/data_malta3_ssd/pytorch/torch/_higher_order_ops/associative_scan.py", line 342, in generic_associative_scan
    scans = _scan(leaves)
  File "/data_malta3_ssd/pytorch/torch/_higher_order_ops/associative_scan.py", line 303, in _scan
    reduced_elems = operator(
  File "/data_malta3_ssd/pytorch/torch/_higher_order_ops/associative_scan.py", line 36, in wrap_combine_fn_flat
    combined = combine_fn(lhs, rhs)
  File "/data_malta3_ssd/pytorch/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
  File "/data_malta3_ssd/pytorch/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/data_malta3_ssd/pytorch/torch/_functorch/vmap.py", line 481, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6340, in combine_fn
    return (x + y) * self.buffer


To execute this test, run the following from the base repo dir:
    python test/export/test_export.py TestExport.test_export_associative_scan_lifted_buffers

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_export_associative_scan_symbol_dim (__main__.TestExport.test_export_associative_scan_symbol_dim)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3120, in wrapper
    method(*args, **kwargs)
  File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6281, in test_export_associative_scan_symbol_dim
    ep = export(Foo(), (xs,), dynamic_shapes={"x": {1: dim1}})
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/__init__.py", line 368, in export
    return _export(
           ^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 2079, in _export
    return _export_for_training(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1944, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1296, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 693, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 1579, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 570, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 1400, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 565, in __call__
    return _compile(
           ^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 997, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 726, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 760, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/bytecode_transformation.py", line 1404, in transform_code_object
    transformations(instructions, code_options)
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 236, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 660, in transform
    tracer = InstructionTranslator(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/symbolic_convert.py", line 2780, in __init__
    self._throw_if_in_functorch()
  File "/data_malta3_ssd/pytorch/torch/_dynamo/symbolic_convert.py", line 2896, in _throw_if_in_functorch
    unimplemented(msg)
  File "/data_malta3_ssd/pytorch/torch/_dynamo/exc.py", line 380, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: If you are reaching here, it means dynamo failed for one of the following reasons:
- Calling torch.func.vmap(compiled_fn) function from eager mode is not supported. Ensure that torch.func.vmap is also wrapped within a torch.compile function. For more information, see PyTorch issue #128711.
- torch.func.vmap(fn) requires the function to be inlined by dynamo

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


To execute this test, run the following from the base repo dir:
    python test/export/test_export.py TestExport.test_export_associative_scan_symbol_dim

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

======================================================================
ERROR: test_export_associative_scan_symbol_scandim (__main__.TestExport.test_export_associative_scan_symbol_scandim)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data_malta3_ssd/pytorch/torch/testing/_internal/common_utils.py", line 3120, in wrapper
    method(*args, **kwargs)
  File "/data_malta3_ssd/pytorch/test/export/test_export.py", line 6313, in test_export_associative_scan_symbol_scandim
    ep = export(Foo(), (xs,), dynamic_shapes={"x": {1: dim1}})
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/__init__.py", line 368, in export
    return _export(
           ^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 2079, in _export
    return _export_for_training(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1044, in wrapper
    raise e
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1017, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/exported_program.py", line 117, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1944, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 1296, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/export/_trace.py", line 693, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 1579, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/eval_frame.py", line 570, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1749, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/nn/modules/module.py", line 1760, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 1400, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 565, in __call__
    return _compile(
           ^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 997, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 726, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 760, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/bytecode_transformation.py", line 1404, in transform_code_object
    transformations(instructions, code_options)
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 236, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/convert_frame.py", line 660, in transform
    tracer = InstructionTranslator(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/data_malta3_ssd/pytorch/torch/_dynamo/symbolic_convert.py", line 2780, in __init__
    self._throw_if_in_functorch()
  File "/data_malta3_ssd/pytorch/torch/_dynamo/symbolic_convert.py", line 2896, in _throw_if_in_functorch
    unimplemented(msg)
  File "/data_malta3_ssd/pytorch/torch/_dynamo/exc.py", line 380, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: If you are reaching here, it means dynamo failed for one of the following reasons:
- Calling torch.func.vmap(compiled_fn) function from eager mode is not supported. Ensure that torch.func.vmap is also wrapped within a torch.compile function. For more information, see PyTorch issue #128711.
- torch.func.vmap(fn) requires the function to be inlined by dynamo

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


To execute this test, run the following from the base repo dir:
    python test/export/test_export.py TestExport.test_export_associative_scan_symbol_scandim

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 3 tests in 2.202s

FAILED (errors=3)```

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the first failure, it looks like an export x vmap problem: we probably have to use the point-wise mode and fix export issues. Since export doesn't require inductor compilation, i'm expecting export to succeed.

For the second failure, i'm actually not sure why we'll trigger this error because the exported model is not under vmap. Does run "python test/export/test_export.py TestExport.test_export_associative_scan_symbol_scandim" alone also fail?

Copy link
Collaborator Author

@bohnstingl bohnstingl Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using pointwise for the first testcase doesn't work. I think inductor is involved and in current lowering additional_inputs are not supported in inductor.

LoweringException: RuntimeError: Unable to generate code for associative_scan op, because there are lifted arguments

For the second issue python test/export/test_export.py TestExport.test_export_associative_scan_symbol_scandim does pass. In fact both of the latter tests pass with that. I.e, using:
python test/export/test_export.py TestExport.test_export_associative_scan_symbol_dim and
python test/export/test_export.py TestExport.test_export_associative_scan_symbol_scandim
both pass

Copy link
Contributor

@ydwu4 ydwu4 Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to update the accociative_scan's backend="eager": https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/associative_scan.py#L137

edit: we can do this change in next PR. Can exp fail this test and remove the exp in next PR.

Copy link
Collaborator Author

@bohnstingl bohnstingl Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I've done the changes except the backend. Indeed, the backend does resolve the remaining issues.

However, the problem with switching the backend to "eager" is that the pointwise check needs to be reworked. This is because the pointwise check is invoked only in the tracing function for inductor and by switching the backend, this check is not invoked anymore.
Could you please start the CI tests?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backend change is pursued in a separate PR here #146973

@ydwu4
Copy link
Contributor

ydwu4 commented Feb 10, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 10, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 2 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test

Details for Dev Infra team Raised by workflow job

@ydwu4
Copy link
Contributor

ydwu4 commented Feb 11, 2025

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 3 checks: linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test, linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Feb 21, 2025
This PR fixes some issues with torch export discussed here: #140043 (comment)

However, this backend change does still not resolve the failure for specific shapes mentioned here: #137943 (comment)

Pull Request resolved: #146973
Approved by: https://github.com/ydwu4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo module: inductor no-stale open source Stale topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants