
AOTInductor: Dynamic Shapes Specification fails for SAM #122294

Open
FabianSchuetze opened this issue Mar 20, 2024 · 4 comments
Labels
module: dynamic shapes, oncall: export, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@FabianSchuetze
Contributor

FabianSchuetze commented Mar 20, 2024

🐛 Describe the bug

I am trying to aot_compile a SAM model with a dynamic shapes specification, but compilation fails with the following error:

[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR] Error while creating guard:
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR] Name: ''
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     Source: shape_env
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     Create Function: SHAPE_ENV
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     Guard Types: None
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     Code List: None
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     Object Weakref: None
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     Guarded Class Weakref: None
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR] Created at:
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]   File "/home/fabian/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 509, in transform
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     tracer = InstructionTranslator(
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]   File "/home/fabian/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2059, in __init__
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     output=OutputGraph(
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]   File "/home/fabian/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 297, in __init__
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     self.init_ambient_guards()
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]   File "/home/fabian/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 371, in init_ambient_guards
[2024-03-20 09:39:49,187] [0/0] torch._guards: [ERROR]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File /tmp/model.py:21
     19 n_labels = torch.export.Dim("n_labels", min=1, max=12)
     20 example_inputs = (img, points, labels)
---> 21 so_path = torch._export.aot_compile(
     22     model,
     23     example_inputs,
     24     dynamic_shapes={
     25                     "img":{},
     26                     "points": {1: n_labels},
     27                     "labels": {1: n_labels}},
     28     # Specify the generated shared library path
     29     options={"aot_inductor.output_path": os.path.join(os.getcwd(), "model.so")},
     30 )

File ~/.local/lib/python3.10/site-packages/torch/_export/__init__.py:1143, in aot_compile(f, args, kwargs, constraints, dynamic_shapes, options, remove_runtime_assertions, disable_constraint_solver)
   1139     constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes)
   1141 # We want to export to Torch IR here to utilize the pre_grad passes in
   1142 # inductor, which run on Torch IR.
-> 1143 gm = _export_to_torch_ir(
   1144     f,
   1145     args,
   1146     kwargs,
   1147     constraints,
   1148     disable_constraint_solver=disable_constraint_solver
   1149 )
   1150 flat_example_inputs = pytree.arg_tree_leaves(*args, **(kwargs or {}))
   1152 with torch.no_grad():

File ~/.local/lib/python3.10/site-packages/torch/_export/__init__.py:516, in _export_to_torch_ir(f, args, kwargs, constraints, preserve_module_call_signature, disable_constraint_solver)
    514     module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
    515     with _wrap_submodules(f, preserve_module_call_signature, module_call_specs):
--> 516         gm_torch_level, _ = torch._dynamo.export(
    517             f,
    518             constraints=constraints,
    519             assume_static_by_default=True,
    520             tracing_mode="symbolic",
    521             disable_constraint_solver=disable_constraint_solver,
    522         )(
    523             *args,
    524             **kwargs,
    525         )
    526 except (ConstraintViolationError, ValueRangeError) as e:
    527     raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: TRY200

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:1342, in export.<locals>.inner(*args, **kwargs)
   1340 # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject.
   1341 try:
-> 1342     result_traced = opt_f(*args, **kwargs)
   1343 except ConstraintViolationError as e:
   1344     constraint_violation_error = e

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:489, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    487     dynamo_config_ctx.__enter__()
    488 try:
--> 489     return fn(*args, **kwargs)
    490 finally:
    491     set_eval_frame(prior)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:655, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_entry, frame_state)
    652             return hijacked_callback(frame, cache_entry, hooks, frame_state)
    654 with compile_lock, _disable_current_modes():
--> 655     return callback(frame, cache_entry, hooks, frame_state)

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:383, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_entry, hooks, frame_state)
    370 signpost_event(
    371     "dynamo",
    372     "_convert_frame_assert._compile",
   (...)
    379     },
    380 )
    382 with config.patch(_patch_config_if_changed()):
--> 383     compiled_product = _compile(
    384         frame.f_code,
    385         frame.f_globals,
    386         frame.f_locals,
    387         frame.f_builtins,
    388         compiler_fn,
    389         one_graph,
    390         export,
    391         export_constraints,
    392         hooks,
    393         cache_size,
    394         frame,
    395         frame_state=frame_state,
    396         compile_id=compile_id,
    397     )
    398 return compiled_product

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:646, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_size, frame, frame_state, compile_id)
    644 with compile_context(CompileContext(compile_id)):
    645     try:
--> 646         guarded_code = compile_inner(code, one_graph, hooks, transform)
    647         return guarded_code
    648     except (
    649         Unsupported,
    650         TorchRuntimeError,
   (...)
    657         BisectValidationException,
    658     ) as e:

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py:244, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    242 with torch.profiler.record_function(f"{key} (dynamo_timed)"):
    243     t0 = time.time()
--> 244     r = func(*args, **kwargs)
    245     time_spent = time.time() - t0
    246 compilation_time_metrics[key].append(time_spent)

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:626, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
    624 assert output.guards is not None
    625 CleanupManager.instance[out_code] = output.cleanups
--> 626 check_fn = CheckFunctionManager(
    627     output,
    628     hooks.guard_fail_fn if hooks else None,
    629 )
    631 guarded_code = GuardedCode(out_code, check_fn.check_fn)
    633 if not output.is_empty_graph() and hooks.guard_export_fn is not None:
    634     # We should not run the guard_export_fn when Dynamo does not
    635     # generate any graph. This can happen in export when TorchDynamo
    636     # generated bytecode has some reconstruction logic for mutated
    637     # variables which can trigger TorchDynamo on the children frames but
    638     # they are benign and do not generate any new graphs.

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py:1011, in CheckFunctionManager.__init__(self, output_graph, guard_fail_fn)
   1000     if (
   1001         not config.guard_nn_modules
   1002         and guard.is_nn_module()
   (...)
   1007         and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
   1008     ):
   1009         continue
-> 1011     guard.create(builder)
   1012 self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
   1013 self._weakrefs.clear()

File ~/.local/lib/python3.10/site-packages/torch/_guards.py:246, in Guard.create(self, builder)
    244 def create(self, builder: GuardBuilderBase):
    245     try:
--> 246         return self.create_fn(builder, self)
    247     except Exception:
    248         log.error("Error while creating guard:\n%s", str(self).rstrip())

File ~/.local/lib/python3.10/site-packages/torch/_dynamo/guards.py:670, in GuardBuilder.SHAPE_ENV(self, guard)
    668 else:
    669     equalities_inputs = None
--> 670 guards = output_graph.shape_env.produce_guards(
    671     [a.fake for a in fs],
    672     [a.source for a in fs],
    673     constraint_inputs=constraint_inputs,
    674     equalities_inputs=equalities_inputs,
    675     source_ref=self.source_ref,
    676     # Export keeps static.
    677     ignore_static=(not self.check_fn_manager.output_graph.export),
    678 )
    679 output_graph.shape_env.freeze()
    680 for shape_guard in guards:

File ~/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:2630, in ShapeEnv.produce_guards(self, placeholders, sources, source_ref, constraint_inputs, equalities_inputs, _simplified, ignore_static)
   2627     return symint.node.expr
   2629 for src1, src2 in equalities_inputs.source_pairs:
-> 2630     s1, s2 = get_symbol(src1), get_symbol(src2)
   2631     concrete_val = self.evaluate_expr(sympy.Eq(s1, s2))
   2632     if not concrete_val:

File ~/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:2626, in ShapeEnv.produce_guards.<locals>.get_symbol(tensor_dim_src)
   2624 fake = placeholders[source_index[tensor_dim_src.base.name()]]
   2625 symint = fake.shape[tensor_dim_src.idx]
-> 2626 assert isinstance(symint, torch.SymInt)
   2627 return symint.node.expr

AssertionError: 

Reproduce

Consider the following model:

import os
import torch

class SAMInterface(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, 3)

    def forward(self, img, points, labels):
        x = self.conv1(img)
        return x

with torch.no_grad():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SAMInterface().to(device=device)
    img=torch.randn(1, 3, 1024, 1024, device=device)
    points = torch.Tensor([[[500.0, 630.5]]])
    labels = torch.Tensor([[[1]]])
    n_labels = torch.export.Dim("n_labels", min=1, max=12)
    example_inputs = (img, points, labels)
    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        dynamic_shapes={
                        "img": {},
                        "points": {1: n_labels},
                        "labels": {1: n_labels}},
        # Specify the generated shared library path
        options={"aot_inductor.output_path": os.path.join(os.getcwd(), "model.so")},
    )

The forward function of SAMInterface takes three arguments: img, points, and labels. For inference, the batch size is always one, but users can vary the number of points and labels from one to twelve. AOT compilation fails with the error posted above. I tried several variations of the dynamic_shapes input, but none succeeded. How can I aot_compile such a model?

Versions

Collecting environment information...
PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.35

Python version: 3.10.13 (main, Sep 5 2023, 06:03:44) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.2.0-35-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Quadro T1000
Nvidia driver version: 535.161.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
CPU family: 6
Model: 158
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 10
CPU max MHz: 4500,0000
CPU min MHz: 800,0000
BogoMIPS: 5199.98
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust sgx bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp sgx_lc md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 1,5 MiB (6 instances)
L3 cache: 12 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS, IBPB conditional, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Mitigation; Microcode
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] fast-pytorch-kmeans==0.2.0.1
[pip3] flake8==6.0.0
[pip3] mypy==1.4.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch-triton==2.1.0+3c400e7818
[pip3] torch==2.2.1
[pip3] torch-tb-profiler==0.4.1
[pip3] torchaudio==2.1.0.dev20230714+cu121
[pip3] torchdata==0.7.0
[pip3] torchprofile==0.0.4
[pip3] torchtext==0.16.0
[pip3] torchvision==0.17.1
[pip3] triton==2.2.0
[conda] Could not collect

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

@ydwu4
Contributor

ydwu4 commented Mar 20, 2024

I can repro the error locally:

(pytorch-3.10) ~/local/pytorch$ python test.py 
E0320 09:18:24.909000 139771639059456 torch/_guards.py:251] [0/0] Error while creating guard:
E0320 09:18:24.909000 139771639059456 torch/_guards.py:251] [0/0] Name: ''
E0320 09:18:24.909000 139771639059456 torch/_guards.py:251] [0/0]     Source: shape_env
E0320 09:18:24.909000 139771639059456 torch/_guards.py:251] [0/0]     Create Function: SHAPE_ENV
E0320 09:18:24.909000 139771639059456 torch/_guards.py:251] [0/0]     Guard Types: None
E0320 09:18:24.909000 139771639059456 torch/_guards.py:251] [0/0]     Code List: None
E0320 09:18:24.909000 139771639059456 torch/_guards.py:251] [0/0]     Object Weakref: None
E0320 09:18:24.909000 139771639059456 torch/_guards.py:251] [0/0]     Guarded Class Weakref: None
E0320 09:18:24.910000 139771639059456 torch/_guards.py:253] [0/0] Created at:
E0320 09:18:24.910000 139771639059456 torch/_guards.py:253] [0/0]   File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in transform
E0320 09:18:24.910000 139771639059456 torch/_guards.py:253] [0/0]     tracer = InstructionTranslator(
E0320 09:18:24.910000 139771639059456 torch/_guards.py:253] [0/0]   File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2101, in __init__
E0320 09:18:24.910000 139771639059456 torch/_guards.py:253] [0/0]     output=OutputGraph(
E0320 09:18:24.910000 139771639059456 torch/_guards.py:253] [0/0]   File "/home/yidi/local/pytorch/torch/_dynamo/output_graph.py", line 345, in __init__
E0320 09:18:24.910000 139771639059456 torch/_guards.py:253] [0/0]     self.init_ambient_guards()
E0320 09:18:24.910000 139771639059456 torch/_guards.py:253] [0/0]   File "/home/yidi/local/pytorch/torch/_dynamo/output_graph.py", line 460, in init_ambient_guards
E0320 09:18:24.910000 139771639059456 torch/_guards.py:253] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Traceback (most recent call last):
  File "/home/yidi/local/pytorch/torch/export/_trace.py", line 349, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 1275, in inner
    raise constraint_violation_error
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 1232, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 390, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 923, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 264, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 634, in compile_inner
    check_fn = CheckFunctionManager(
  File "/home/yidi/local/pytorch/torch/_dynamo/guards.py", line 1039, in __init__
    guard.create(builder)
  File "/home/yidi/local/pytorch/torch/_guards.py", line 249, in create
    return self.create_fn(builder, self)
  File "/home/yidi/local/pytorch/torch/_dynamo/guards.py", line 696, in SHAPE_ENV
    guards = output_graph.shape_env.produce_guards(
  File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3313, in produce_guards
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (n_labels)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of n_labels = L['points'].size()[1] in the specified range n_labels <= 12 are valid because n_labels was inferred to be a constant (1).
  - Not all values of n_labels = L['labels'].size()[1] in the specified range n_labels <= 12 are valid because n_labels was inferred to be a constant (1).

Suggested fixes:
  n_labels = None  # 1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/test.py", line 48, in <module>
    so_path = torch._export.aot_compile(
  File "/home/yidi/local/pytorch/torch/_export/__init__.py", line 358, in aot_compile
    gm = _export_to_torch_ir(
  File "/home/yidi/local/pytorch/torch/export/_trace.py", line 361, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: TRY200
torch._dynamo.exc.UserError: Constraints violated (n_labels)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of n_labels = L['points'].size()[1] in the specified range n_labels <= 12 are valid because n_labels was inferred to be a constant (1).
  - Not all values of n_labels = L['labels'].size()[1] in the specified range n_labels <= 12 are valid because n_labels was inferred to be a constant (1).

Suggested fixes:
  n_labels = None  # 1

Seems related to 0-1 specialization.

By changing the example inputs to

    points = torch.Tensor([[[500.0, 630.5], [500.0, 630.5]]])
    labels = torch.Tensor([[[1], [1]]])

the error goes away, because the n_labels dimension is now 2 instead of 1 (size 1 is treated as a special constant due to 0-1 specialization). The error message could also be improved for this case, I think. cc @avikchaudhuri
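
For reference, a sketch of the full repro from above with only the two example tensors changed (everything else is identical to the original script):

import os
import torch

class SAMInterface(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, 3)

    def forward(self, img, points, labels):
        return self.conv1(img)

with torch.no_grad():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SAMInterface().to(device=device)
    img = torch.randn(1, 3, 1024, 1024, device=device)
    # Size 2 along dim 1, so n_labels is not inferred to be the constant 1.
    points = torch.Tensor([[[500.0, 630.5], [500.0, 630.5]]])
    labels = torch.Tensor([[[1], [1]]])
    n_labels = torch.export.Dim("n_labels", min=1, max=12)
    so_path = torch._export.aot_compile(
        model,
        (img, points, labels),
        dynamic_shapes={"img": {}, "points": {1: n_labels}, "labels": {1: n_labels}},
        options={"aot_inductor.output_path": os.path.join(os.getcwd(), "model.so")},
    )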

@avikchaudhuri
Contributor

Yeah, I was going to suggest the same thing. @FabianSchuetze: because many torch ops treat size=1 specially, the PyTorch compiler also treats it specially: to use a dynamic shape, you have to select a size > 1 for the corresponding example input dimension.
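
Broadcasting illustrates why a trace taken with size=1 is ambiguous; a minimal sketch using only core torch ops:

import torch

a = torch.randn(1, 4)  # size-1 leading dim
b = torch.randn(3, 4)
print((a + b).shape)   # torch.Size([3, 4]): the size-1 dim broadcasts
# If a's leading dim were 3, the same expression would add elementwise
# without broadcasting, so a trace taken with size 1 cannot represent both
# behaviors under one symbolic shape; the compiler specializes on 1 instead.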

@ydwu4 @angelayi The reported error is an assertion error, though, and it was already fixed in #121599

@ezyang
Contributor

ezyang commented Mar 21, 2024

@avikchaudhuri A step we could potentially take beyond #122090 is to make dynamic dims specified in export act as unbacked ints rather than backed ints, in which case it doesn't matter whether you pass us a 1 or a 2; we won't specialize on the 1. It could be confusing in a different way, though, because (1) you'll get errors immediately during tracing, rather than at the end, and (2) comparisons against 1 are always going to return False, even if the sample is size one lol.

@FabianSchuetze
Contributor Author

Thanks for the comments, @avikchaudhuri and @ydwu4.

That is indeed an applicable workaround - thanks! The consequence is that SAM generates two identical masks, but I can discard them easily.
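
In other words, the inference-side workaround is roughly the following sketch, where run_compiled_model is a hypothetical stand-in for however the generated model.so is invoked:

import torch

def predict_single_point(run_compiled_model, img, point, label):
    # Duplicate the single point/label so the dynamic n_labels dim is 2.
    points = point.repeat(1, 2, 1)   # (1, 1, 2) -> (1, 2, 2)
    labels = label.repeat(1, 2, 1)   # (1, 1, 1) -> (1, 2, 1)
    masks = run_compiled_model(img, points, labels)
    # The two prompts are identical, so the masks are too; keep one
    # (the exact slicing depends on SAM's output layout).
    return masks[:1]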

@jansel added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and removed the triage review label Mar 23, 2024