
torch.compile gives correct index values (if those are returned), but not the indexed values. #125387

Closed
ydshieh opened this issue May 2, 2024 · 6 comments


@ydshieh

ydshieh commented May 2, 2024

🐛 Describe the bug

(This is not a 100% duplicate of #124357: that one has a workaround of setting torch._dynamo.config.guard_nn_modules=True, but that workaround does not help here.)

Issue occurs for: 2.4.0.dev20240501+cu121 (possibly also earlier nightly versions)
No issue for: torch 2.2 / 2.3

Description

torch.compile gives the correct index values (if the function returns those index values), but if I use those indices to index into a tensor and return the selected values, the outputs (and the indices they imply) stop changing after a few calls.

See the following (simple) code snippet. Adding torch._dynamo.config.guard_nn_modules=True has no effect in this example.

Code snippet

import torch
import torch._dynamo as dynamo
import torch._inductor as inductor
import torch.nn as nn

dynamo.reset()


# RETURN_INDEX = True: the outputs are 0, 1, 2, 3, 4, 5 (this is expected)
# RETURN_INDEX = False: the outputs are 2, 3, 4, 5, 5, 5 (it should be 2, 3, 4, 5, 6, 7)
RETURN_INDEX = True

class ToyModel(torch.nn.Module):
    def __init__(self, return_index):
        super(ToyModel, self).__init__()
        self.value = -1
        self.return_index = return_index
        self.cache = torch.tensor([2, 3, 4, 5, 6, 7])

    def forward(self, value):
        self.value += 1
        if self.return_index:
            return self.value  # the outputs are: 0, 1, 2, 3, 4, 5
        else:
            return self.cache[self.value]  # the outputs are:  2, 3, 4, 5, 5, 5

model = ToyModel(return_index=RETURN_INDEX)
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

values = [6, 8, 10, 12, 13, 14]
for value in values:
    output = model.forward(value)
    print(f"output = {output}")

RETURN_INDEX = True gives (0, 1, 2, 3, 4, 5):

output = 0
output = 1
output = 2
output = 3
output = 4
output = 5

RETURN_INDEX = False gives (2, 3, 4, 5, 5, 5):

output = 2
output = 3
output = 4
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542] Ignored guard s0 + 1 == 3, this could result in accuracy problems
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542] Stack (most recent call last):
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "temp.py", line 29, in <module>
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     output = model.forward(value)
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     return fn(*args, **kwargs)
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "temp.py", line 17, in forward
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     def forward(self, value):
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     return fn(*args, **kwargs)
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 36, in inner
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     return fn(*args, **kwargs)
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 991, in forward
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     return compiled_fn(full_args)
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 262, in runtime_wrapper
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     regenerated_out = gen_alias_from_base(
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/_aot_autograd/functional_utils.py", line 230, in gen_alias_from_base
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     out = torch._functionalize_apply_view_metas(
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/fx/experimental/sym_node.py", line 355, in guard_int
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/fx/experimental/recording.py", line 244, in wrapper
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     return fn(*args, **kwargs)
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 4708, in evaluate_expr
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     self._check_frozen(expr, concrete_val)
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]   File "/usr/local/lib/python3.8/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 4542, in _check_frozen
W0502 10:15:04.605925 140253495338816 torch/fx/experimental/symbolic_shapes.py:4542]     log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val, stack_info=True)
output = 5
output = 5
output = 5

Versions

Collecting environment information...
PyTorch version: 2.4.0.dev20240501+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.20.3
Libc version: glibc-2.31

Python version: 3.8.10 (default, Nov 22 2023, 10:22:35)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.214-202.855.amzn2.x86_64-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A10G
Nvidia driver version: 535.161.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             16
On-line CPU(s) list:                0-15
Thread(s) per core:                 2
Core(s) per socket:                 8
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          AuthenticAMD
CPU family:                         23
Model:                              49
Model name:                         AMD EPYC 7R32
Stepping:                           0
CPU MHz:                            3267.235
BogoMIPS:                           5600.00
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          256 KiB
L1i cache:                          256 KiB
L2 cache:                           4 MiB
L3 cache:                           32 MiB
NUMA node0 CPU(s):                  0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.2.0
[pip3] mypy-extensions==1.0.0
[pip3] natten==0.15.1+torch220cu118
[pip3] numpy==1.24.3
[pip3] onnx==1.16.0
[pip3] onnxconverter-common==1.13.0
[pip3] onnxruntime==1.17.3
[pip3] onnxruntime-tools==1.7.0
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] tf2onnx==1.16.1
[pip3] torch==2.4.0.dev20240501+cu121
[pip3] torchaudio==2.2.0.dev20240501+cu121
[pip3] torchvision==0.19.0.dev20240501+cu121
[pip3] triton==2.2.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78

(just like #124357)

@ydshieh
Author

ydshieh commented May 2, 2024

While using 2.4.0.dev20240501+cu121 (the version for which the issue exists), if I put both cases in the same script, like

RETURN_INDEX = True
model = ToyModel(return_index=RETURN_INDEX)
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

values = [6, 8, 10, 12, 13, 14]
for value in values:
    output = model.forward(value)
    print(f"output = {output}")



RETURN_INDEX = False
model = ToyModel(return_index=RETURN_INDEX)
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

values = [6, 8, 10, 12, 13, 14]
for value in values:
    output = model.forward(value)
    print(f"output = {output}")

the outputs are correct:

output = 0
output = 1
output = 2
output = 3
output = 4
output = 5

output = 2
output = 3
output = 4
output = 5
output = 6
output = 7

@ezyang
Contributor

ezyang commented May 3, 2024

This is almost definitely related to the "ignored guard" warning.

@ydshieh
Author

ydshieh commented May 3, 2024

Just to make it more visible: this seems to be a regression.

issue occurs for: 2.4.0.dev20240501+cu121 (possibly also earlier nightly versions)
no issue : torch 2.2 / 2.3

@lezcano
Collaborator

lezcano commented May 6, 2024

The guard comes from the autogen code:

  return at::native::as_strided_tensorimpl(self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride), storage_offset.has_value() ? ::std::make_optional(storage_offset->guard_int(__FILE__, __LINE__)) : ::std::nullopt);

This fails in runtime_wrappers.py, line 262. This comes from #121007, added by @ysiraichi.

The issue is that the current implementation in gen_alias_from_base expects apply_view_metas to fail when run with dynamic shapes. Since #113921 we specialise symints, so we no longer get a hard error and instead fail silently.

An easy fix would be to hard error instead of warning; this indeed fixes the issue.
An even better fix would be to expose a method that detects whether a functional tensor has dynamic shapes, and remove the try-except hack from #121007.

@ysiraichi Mind taking a look? If you think it'd take you quite long, ping me and I can submit a fix.
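
(To make the failure mode concrete, here is a minimal, self-contained Python analogy, not the actual Dynamo/ShapeEnv code; FrozenGuardAnalogy and its fields are invented for illustration. Once the environment is "frozen", a guard that would normally force recompilation is only logged as a warning, so the stale specialized result keeps being returned, which is the 2, 3, 4, 5, 5, 5 pattern reported above; hard-erroring instead of warning surfaces the problem loudly.)

# Minimal analogy of "a frozen env ignores new guards"; not PyTorch internals,
# all names here are hypothetical.
class FrozenGuardAnalogy:
    def __init__(self, cache):
        self.cache = cache
        self.frozen = False      # stand-in for a frozen shape environment
        self.specialized = None  # value the last (re)compilation specialized on

    def call(self, index, hard_error=False):
        if not self.frozen:
            self.specialized = index          # recompiling is still allowed
        elif index != self.specialized:       # the guard would fail here
            msg = f"Ignored guard {index} == {self.specialized}"
            if hard_error:
                raise RuntimeError(msg)       # the proposed "easy fix": fail loudly
            print("W:", msg, "- this could result in accuracy problems")
        return self.cache[self.specialized]

g = FrozenGuardAnalogy([2, 3, 4, 5, 6, 7])
outs = []
for i in range(6):
    if i == 4:
        g.frozen = True  # pretend the env freezes after a few iterations
    outs.append(g.call(i))
print(outs)  # [2, 3, 4, 5, 5, 5] -- stale values, mirroring the report above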

@ysiraichi
Collaborator

@ydshieh With #124948, you won't be seeing this issue by default, since it's disabled.

However, enabling that dynamo configuration does run into the same issue. I think the fix is, as @lezcano said, to detect whether a FunctionalTensor has a view operation that uses a SymInt. That would also possibly make the regressions go away, so it could become the default behavior again.

I will work on it.
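
(A rough sketch, in Python rather than the C++ ViewMeta machinery, of the check described above; has_symbolic_inputs here is a hypothetical standalone helper, not an existing PyTorch API. The idea is simply to detect whether any recorded view argument is a SymInt and, if so, skip the ViewMeta-replay path.)

import torch

def has_symbolic_inputs(args) -> bool:
    # True if any (possibly nested) view argument is symbolic.
    for a in args:
        if isinstance(a, torch.SymInt):
            return True
        if isinstance(a, (list, tuple)) and has_symbolic_inputs(a):
            return True
    return False

# Example: an as_strided(size, stride, storage_offset) view recorded with a
# symbolic storage_offset would be marked symbolic, so AOTAutograd would fall
# back instead of replaying it with a stale, specialized value.
print(has_symbolic_inputs([(2, 3), (3, 1), 0]))  # False for plain ints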

ysiraichi added a commit that referenced this issue May 9, 2024
Fix: #125387

This PR helps keep track of whether an instantiated `ViewMeta` has symbolic values as input or not. This is used for checking whether we use the AOTAutograd `ViewMeta`-replay execution path, since it doesn't support tensors that have `ViewMeta` with symbolic inputs.

In summary, the changes are:

- Add the field `ViewMeta::has_symbolic_inputs` and make it a required constructor parameter
- Add the field `FunctionalTensorWrapper::is_symbolic_` and the method `FunctionalTensorWrapper::maybe_mark_symbolic`
    - Marks a `FunctionalTensorWrapper` as symbolic iff any of its `ViewMeta` have symbolic inputs
- Add the plumbing of `FunctionalTensorWrapper::is_symbolic` to the Python API
- Codegen the computation of `ViewMeta::has_symbolic_inputs` for each view operation
- Use the AOTAutograd `ViewMeta`-replay path if:
    - `target_functional_tensor` is not `None`; and
    - `target_functional_tensor` is not symbolic (instead of using a functorch config)

Pull Request resolved: #125876
@ydshieh
Author

ydshieh commented May 13, 2024

Thanks for the fix @ysiraichi! Confirmed it works (with the provided script).

tinglvv pushed a commit to tinglvv/pytorch that referenced this issue May 14, 2024
Fix: pytorch#125387
Pull Request resolved: pytorch#125876
Approved by: https://github.com/ezyang