ConstraintViolationError not thrown when constraining dynamic dim to static int #122307

Open
kadeng opened this issue Mar 20, 2024 · 6 comments · May be fixed by #122913 or #123293
Labels
good first issue module: dynamic shapes small We think this is a small issue to fix. Consider knocking off high priority small issues triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@kadeng
Contributor

kadeng commented Mar 20, 2024

🐛 Describe the bug

The following code does not throw, even though x.shape[0] is marked as dynamic and is constrained to a static value (5) by the end of the function.

According to @lezcano, the correct behavior would be to throw a ConstraintViolationError.
See the comment thread in #121808.

import torch
import torch._dynamo.testing
import torch._inductor.ops_handler
import torch._inductor.utils
import torch._inductor
import torch._dynamo

def test_repro():

    def fn_2(x):
        # constrain in two directions
        if x.shape[0] > 5:
            return x.cos()
        if x.shape[0] < 5:
            return x * 2
        # x.shape[0] == 5 at this point
        return x.sin()

    torch._dynamo.reset()
    _x = torch.randn([5, 3, 3])
    torch._dynamo.mark_dynamic(_x, 0)
    torch.compile(backend="inductor", dynamic=None)(fn_2)(_x)

test_repro()

Versions

Collecting environment information...
PyTorch version: 2.3.0a0+git1079f56
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
Clang version: 17.0.6 (CentOS 17.0.6-5.el9)
CMake version: version 3.26.4
Libc version: glibc-2.34

Python version: 3.11.3 (main, May 15 2023, 15:45:52) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-0_fbk12_zion_11583_g0bef9520ca2b-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100
GPU 1: NVIDIA H100

Nvidia driver version: 525.105.17
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 92
On-line CPU(s) list: 0-91
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9654 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 1
Core(s) per socket: 92
Socket(s): 1
Stepping: 1
BogoMIPS: 4792.23
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm arch_capabilities
Virtualization: AMD-V
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 5.8 MiB (92 instances)
L1i cache: 5.8 MiB (92 instances)
L2 cache: 46 MiB (92 instances)
L3 cache: 1.4 GiB (92 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-91
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Vulnerable, IBPB: conditional, IBRS_FW, STIBP: disabled, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] bert-pytorch==0.0.1a4
[pip3] clip-anytorch==2.6.0
[pip3] CoCa-pytorch==0.1.0
[pip3] dalle2-pytorch==1.14.2
[pip3] ema-pytorch==0.3.2
[pip3] flake8==3.8.2
[pip3] flake8-bugbear==20.1.4
[pip3] flake8-coding==1.3.3
[pip3] flake8-comprehensions==3.3.0
[pip3] flake8-executable==2.0.4
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==20.5.0
[pip3] flake8-simplify==0.19.3
[pip3] functorch==1.14.0a0+b71aa0b
[pip3] mypy==1.6.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] onnx==1.15.0
[pip3] onnxruntime-gpu==1.15.1
[pip3] open-clip-torch==2.24.0
[pip3] optree==0.10.0
[pip3] pytorch-warmup==0.1.1
[pip3] rotary-embedding-torch==0.3.3
[pip3] torch==2.3.0a0+git1079f56
[pip3] torch-fidelity==0.3.0
[pip3] torch_geometric==2.4.0
[pip3] torchaudio==2.2.0.dev20240119+cu121
[pip3] torchbench==0.1
[pip3] torchmetrics==1.0.3
[pip3] torchmultimodal==0.1.0b0
[pip3] torchrec==0.5.0
[pip3] torchvision==0.18.0.dev20240119+cu121
[pip3] torchvision==0.18.0a0+6f0deb9
[pip3] torchx==0.6.0
[pip3] triton==3.0.0
[pip3] triton-nightly==2.1.0.post20240108192258
[pip3] vector_quantize_pytorch==1.12.12
[conda] bert-pytorch 0.0.1a4 dev_0
[conda] blas 1.0 mkl
[conda] clip-anytorch 2.6.0 pypi_0 pypi
[conda] coca-pytorch 0.1.0 pypi_0 pypi
[conda] dalle2-pytorch 1.14.2 pypi_0 pypi
[conda] ema-pytorch 0.3.2 pypi_0 pypi
[conda] functorch 1.14.0a0+b71aa0b pypi_0 pypi
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.1.0 h06a4308_46342
[conda] mkl-service 2.4.0 py311h5eee18b_1
[conda] mkl_fft 1.3.6 py311ha02d727_1
[conda] mkl_random 1.2.2 py311ha02d727_1
[conda] numpy 1.23.5 pypi_0 pypi
[conda] open-clip-torch 2.24.0 pypi_0 pypi
[conda] optree 0.10.0 pypi_0 pypi
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch
[conda] pytorch-warmup 0.1.1 pypi_0 pypi
[conda] rotary-embedding-torch 0.3.3 pypi_0 pypi
[conda] torch 2.3.0a0+git1079f56 dev_0
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torchaudio 2.2.0.dev20240119+cu121 pypi_0 pypi
[conda] torchbench 0.1 dev_0
[conda] torchfix 0.2.0 pypi_0 pypi
[conda] torchmetrics 1.0.3 pypi_0 pypi
[conda] torchmultimodal 0.1.0b0 pypi_0 pypi
[conda] torchrec 0.5.0 pypi_0 pypi
[conda] torchvision 0.18.0a0+6f0deb9 pypi_0 pypi
[conda] torchx 0.6.0 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi
[conda] triton-nightly 2.1.0.post20240108192258 pypi_0 pypi
[conda] vector-quantize-pytorch 1.12.12 pypi_0 pypi

cc @ezyang

@lezcano lezcano added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module module: dynamic shapes labels Mar 20, 2024
@ezyang
Contributor

ezyang commented Mar 20, 2024

This is because var range refinement doesn't imply replacement, but we are only checking replacements here. Should be easy to fix.
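To make the distinction concrete, here is a toy model in plain Python. It is not PyTorch's ShapeEnv: the `ToyShapeEnv` class, its method names, and the default bounds are hypothetical, and only the attribute names `var_to_range` and `replacements` echo the terms used in this thread. It sketches how two inequality guards can pin a symbol to a single value without any replacement being recorded.

```python
from dataclasses import dataclass, field


@dataclass
class ToyShapeEnv:
    """Hypothetical stand-in for a shape environment; not PyTorch's ShapeEnv."""

    # Symbol name -> inclusive (lower, upper) bounds.
    var_to_range: dict = field(default_factory=dict)
    # Symbol name -> constant; filled only by equality guards.
    replacements: dict = field(default_factory=dict)

    def add_symbol(self, name, lower=2, upper=10**9):
        # The default bounds are illustrative assumptions.
        self.var_to_range[name] = (lower, upper)

    def guard_le(self, name, bound):
        """Record `name <= bound`: refines the range only."""
        lo, hi = self.var_to_range[name]
        self.var_to_range[name] = (lo, min(hi, bound))

    def guard_ge(self, name, bound):
        """Record `name >= bound`: refines the range only."""
        lo, hi = self.var_to_range[name]
        self.var_to_range[name] = (max(lo, bound), hi)

    def guard_eq(self, name, value):
        """Record `name == value`: refines the range and sets a replacement."""
        self.var_to_range[name] = (value, value)
        self.replacements[name] = value


env = ToyShapeEnv()
env.add_symbol("s0")
# fn_2 falls through both branches when x.shape[0] == 5:
env.guard_le("s0", 5)  # `x.shape[0] > 5` evaluated to False
env.guard_ge("s0", 5)  # `x.shape[0] < 5` evaluated to False

print(env.var_to_range["s0"])    # (5, 5)
print("s0" in env.replacements)  # False
```

In this model, a check that looks only at `replacements` sees nothing wrong, even though the range has collapsed to the single value 5.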

@lezcano
Collaborator

lezcano commented Mar 20, 2024

Marking as good first issue.

A valid solution would send a PR that:

  • Checks var_to_range and makes sure that the variable is not a singleton.
  • Adds a regression test.
  • Extra points if it simplifies the current logic so that it just uses var_ranges. When doing a replacement with a constant, I believe we always refine the range, so a replacement should hopefully imply a singleton var_range. Adding an assertion where we perform the replacements to make sure that this is the case would be great!
  • Another approach would be to register the replacement once we have a value range that is a singleton.

Feel free to directly send a PR and tag me as a reviewer :)
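The first suggestion in the list above could look roughly like the following. This is a sketch under stated assumptions, not the actual PyTorch patch: `is_singleton` and `relaxed_constraint_violated` are hypothetical helpers, ranges are modeled as plain `(lower, upper)` tuples, and symbols are plain strings. The 0/1 exemption mirrors the RelaxedUnspecConstraint snippet quoted later in this thread.

```python
def is_singleton(value_range):
    """Return True if an inclusive (lower, upper) range holds exactly one value."""
    lower, upper = value_range
    return lower == upper


def relaxed_constraint_violated(symbol, replacements, var_to_range):
    """Hypothetical check: was a symbol marked dynamic specialized to a constant?

    A symbol counts as specialized if it was replaced by a constant OR if
    its value range collapsed to a single value.
    """
    value = replacements.get(symbol)
    value_range = var_to_range.get(symbol)
    if value is None and value_range is not None and is_singleton(value_range):
        value = value_range[0]
    if value is None:
        return False
    # Specialization to 0 or 1 is tolerated, as in the quoted snippet.
    return value not in (0, 1)


# Range collapsed to 5 with no replacement recorded: now reported.
print(relaxed_constraint_violated("s0", {}, {"s0": (5, 5)}))    # True
# Still a genuine range: fine.
print(relaxed_constraint_violated("s0", {}, {"s0": (2, 100)}))  # False
# Collapsed to 1: exempt.
print(relaxed_constraint_violated("s0", {}, {"s0": (1, 1)}))    # False
```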

@ezyang ezyang added the small We think this is a small issue to fix. Consider knocking off high priority small issues label Mar 21, 2024
@Episkey0109
Contributor

Episkey0109 commented Mar 22, 2024

Hi @lezcano! I would like to work on this issue, and I want to check that I understand it correctly. When the relation is an equality and the value of the variable is not “statically known”, _maybe_guard_rel both refines the ranges and sets a replacement; for other relational expressions it only refines the ranges. When track_symint checks whether a RelaxedUnspecConstraint is violated, it only checks whether the variable is a number (s.is_number) after replacement, so we should add (or possibly use only) a check of whether the variable's range in var_to_range is a singleton.

I also want to clarify the parameter val (a SymInt) in track_symint(source, val, constraint=None): does it correspond to a variable or to an expression in the original code? I’m a little confused by the usage of SymInt. In the following code, the first branch uses s.free_symbols and the second branch uses s.is_number, and I would appreciate clarification on why. Thanks!

if isinstance(constraint, StrictMinMaxConstraint):
    # try inferring the ranges of the expr s
    sym_vrs = {x: self.var_to_range.get(x, None) for x in s.free_symbols}
    if any(vr is None for vr in sym_vrs.values()):
        # some of the free symbols in s don't have ranges
        constraint_violated = True
elif isinstance(constraint, RelaxedUnspecConstraint):
    if s.is_number:
        i = int(s)
        # Don't complain about 0/1 specialization, we
        # expect to have to compile in this case anyway
        if i not in (0, 1):
            constraint_violated = True

@lezcano
Collaborator

lezcano commented Mar 22, 2024

What about the following approach (different from the main one I suggested above)?

Another approach would be to register the replacement once we refine a value range into a singleton.

About that question: it very much looks like s is an expression there. If in doubt, put a conditional breakpoint there that breaks if len(sym_vrs) > 1, run a few tests that exercise that path, and see what you find :)
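The alternative approach quoted above (registering the replacement once a range is refined into a singleton) might be sketched as follows. `ToyRangeTracker` and its methods are hypothetical names for illustration and do not reflect PyTorch's real ShapeEnv API; the point is only that every range update goes through one place, which records a replacement as soon as the bounds meet.

```python
class ToyRangeTracker:
    """Hypothetical tracker; not PyTorch's ShapeEnv."""

    def __init__(self):
        self.var_to_range = {}   # symbol name -> inclusive (lower, upper)
        self.replacements = {}   # symbol name -> constant

    def add_symbol(self, name, lower, upper):
        self._set_range(name, lower, upper)

    def refine(self, name, lower=None, upper=None):
        """Intersect the current range of `name` with [lower, upper]."""
        cur_lower, cur_upper = self.var_to_range[name]
        new_lower = cur_lower if lower is None else max(cur_lower, lower)
        new_upper = cur_upper if upper is None else min(cur_upper, upper)
        if new_lower > new_upper:
            raise ValueError(f"empty range for {name}")
        self._set_range(name, new_lower, new_upper)

    def _set_range(self, name, lower, upper):
        self.var_to_range[name] = (lower, upper)
        # Single choke point: a singleton range implies a replacement.
        if lower == upper:
            self.replacements[name] = lower


tracker = ToyRangeTracker()
tracker.add_symbol("s0", 2, 1000)
tracker.refine("s0", upper=5)   # from `not (s0 > 5)`
print(tracker.replacements)     # {}
tracker.refine("s0", lower=5)   # from `not (s0 < 5)`
print(tracker.replacements)     # {'s0': 5}
```

With this design, a check that only inspects replacements (such as the s.is_number test after substitution) would also catch the indirect specialization, at the cost of keeping the two tables in sync.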

@Episkey0109
Contributor

Yes, I think this approach will work. And thanks for your clarification!

@Episkey0109
Contributor

Episkey0109 commented Mar 22, 2024

Hi @lezcano, I think I’ve fixed the issue, and I’ve added a test in test/dynamo/test_misc.py. I want to make sure that this is the right place for a regression test. I added it to this file because test_raise_guard_full_constraint is located here, and the new test is similar to it, except that the constraint is indirect.
