
Dynamo failure due to non constants input of aten.lift_fresh_copy #113331

Closed
leslie-fang-intel opened this issue Nov 9, 2023 · 5 comments


leslie-fang-intel commented Nov 9, 2023

🐛 Describe the bug

For DistilBert as defined in Transformers, the mask fill value is created at runtime.

This causes an error when invoking torch.compile; the detailed failure message is:

  File "/home/jianan/leslie/torch_inductor_lz/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/jianan/leslie/torch_inductor_lz/pytorch/torch/_subclasses/fake_tensor.py", line 1390, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/jianan/leslie/torch_inductor_lz/pytorch/torch/_subclasses/fake_tensor.py", line 1482, in dispatch
    assert all(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function aten.lift_fresh_copy.default(*(FakeTensor(..., size=()),), **{}):
aten.lift_fresh_copy.default should not have fake inputs without constants

from user code:
   File "<eval_with_key>.3", line 9, in forward
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Example code to reproduce this failure:

import torch
import math
from torch._export import capture_pre_autograd_graph

class M(torch.nn.Module):
    def __init__(self,):
        super().__init__()

    def forward(self, scores, mask):
        scores = scores.masked_fill(
            mask, torch.tensor(torch.finfo(scores.dtype).min)
        )  # (bs, n_heads, q_length, k_length)
        return scores

if __name__ == "__main__":
    tensor_cpu = torch.randn(2, 4)
    mask_cpu = torch.BoolTensor(
        [[False,  True, False, False],
        [False, False, False, False]]
    )

    m = M().eval()
    # res_ref = m(tensor_cpu, mask_cpu)
    # print("res_ref is: {}".format(res_ref), flush=True)

    exported_model = capture_pre_autograd_graph(
        m,
        (tensor_cpu, mask_cpu),
    )
    print(exported_model, flush=True)
    optimized_model = torch.compile(exported_model)
    optimized_model(tensor_cpu, mask_cpu)
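A possible workaround (a sketch, not from the original report): masked_fill also accepts a plain Python scalar, so passing torch.finfo(...).min directly avoids creating a fresh tensor constant inside forward, and therefore avoids the aten.lift_fresh_copy node that fails under fake tensors.

```python
import torch

class M(torch.nn.Module):
    def forward(self, scores, mask):
        # Pass the Python float directly: masked_fill accepts a scalar,
        # so no fresh tensor constant (and no aten.lift_fresh_copy node)
        # is created inside forward.
        return scores.masked_fill(mask, torch.finfo(scores.dtype).min)

if __name__ == "__main__":
    scores = torch.randn(2, 4)
    mask = torch.tensor([[False, True, False, False],
                         [False, False, False, False]])
    out = torch.compile(M())(scores, mask)
    print(out)
```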

Versions

Collecting environment information...
PyTorch version: 2.2.0a0+gitff51f94
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.5.0-1ubuntu1~22.04) 9.5.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-53-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   52 bits physical, 57 bits virtual
Byte Order:                      Little Endian
CPU(s):                          224
On-line CPU(s) list:             0-223
Vendor ID:                       GenuineIntel
Model name:                      Intel (R) Xeon (R) CPU Max 9480
CPU family:                      6
Model:                           143
Thread(s) per core:              2
Core(s) per socket:              56
Socket(s):                       2
Stepping:                        8
CPU max MHz:                     3500.0000
CPU min MHz:                     800.0000
BogoMIPS:                        3800.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       5.3 MiB (112 instances)
L1i cache:                       3.5 MiB (112 instances)
L2 cache:                        224 MiB (112 instances)
L3 cache:                        225 MiB (2 instances)
NUMA node(s):                    4
NUMA node0 CPU(s):               0-55,112-167
NUMA node1 CPU(s):               56-111,168-223
NUMA node2 CPU(s):
NUMA node3 CPU(s):
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.2.0+git5575b65
[pip3] numpy==1.26.0
[pip3] onnx==1.15.0
[pip3] torch==2.2.0a0+gitff51f94
[pip3] torchvision==0.17.0a0+f69eee6
[conda] blas                      1.0                         mkl
[conda] intel-extension-for-pytorch 2.2.0+git5575b65           dev_0    <develop>
[conda] mkl                       2023.1.0         h213fc3f_46343
[conda] mkl-include               2023.2.0                 pypi_0    pypi
[conda] mkl-service               2.4.0           py310h5eee18b_1
[conda] mkl-static                2023.2.0                 pypi_0    pypi
[conda] mkl_fft                   1.3.8           py310h5eee18b_0
[conda] mkl_random                1.2.4           py310hdb19cb5_0
[conda] numpy                     1.26.0          py310h5f9d8c6_0
[conda] numpy-base                1.26.0          py310hb5e798b_0
[conda] torch                     2.2.0a0+gitff51f94           dev_0    <develop>
[conda] torchvision               0.17.0a0+f69eee6          pypi_0    pypi

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

@leslie-fang-intel leslie-fang-intel changed the title Dynamo failure in due to non constants input of aten.lift_fresh_copy Dynamo failure due to non constants input of aten.lift_fresh_copy Nov 9, 2023

leslie-fang-intel commented Nov 9, 2023

On further investigation, here is the debug trace of this fake tensor:

34: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/variables/builtin.py, line 1094 in call_getattr>
35: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/variables/nn_module.py, line 218 in var_getattr>
36: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/variables/builder.py, line 237 in __call__>
37: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/variables/builder.py, line 382 in _wrap>
38: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/variables/builder.py, line 1027 in wrap_tensor>
39: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/output_graph.py, line 776 in register_attr_or_module>
40: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/output_graph.py, line 707 in wrap_name>
41: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/variables/builder.py, line 1353 in wrap_fx_proxy>
42: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/variables/builder.py, line 1469 in wrap_fx_proxy_cls>
43: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/variables/builder.py, line 1813 in wrap_to_fake_tensor_and_record>
44: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/utils.py, line 994 in wrap_fake_exception>
45: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_dynamo/variables/builder.py, line 1814 in <lambda>>
46: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_subclasses/fake_tensor.py, line 1878 in from_tensor>
47: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_subclasses/fake_tensor.py, line 411 in __call__>
48: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_subclasses/fake_tensor.py, line 364 in from_real_tensor>
49: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_subclasses/meta_utils.py, line 707 in __call__>
50: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_subclasses/meta_utils.py, line 516 in meta_tensor>
51: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_subclasses/fake_tensor.py, line 357 in mk_fake_tensor>
52: <FrameSummary file /home/jianan/leslie/torch_inductor_lz/pytorch/torch/_subclasses/fake_tensor.py, line 1137 in __new__>

It looks like the make_constant parameter is not turned on in this call:

return self.fake_tensor_converter(
    self,
    tensor,
    shape_env=shape_env,
    ignore_subclass=ignore_subclass,
    source=source,
    dynamic_dims=dynamic_dims,
    constraint_dims=constraint_dims,
    memoized_only=memoized_only,
)

cc @jgong5

leslie-fang-intel commented:

Hi @eellison, could you kindly help take a look at this issue?

jon-chuang commented:

This is because Dynamo currently doesn't initialize its fake tensors with constants, i.e. we have fake_tensor_converter(make_constant=False).


jon-chuang commented Nov 9, 2023

I think it is fine to set make_constant=True, as long as the tensor originates in the graph, and to allow FakeTensor to invalidate the constant when it is mutated in place.

However, we would need to do a lot of work to ensure we know when the tensor is going to be constant.


Further: in the case of compiling the exported graph, the tensor source is a GraphModule (nn.Module).

It is even harder to prove the tensor is constant in this case than when it originates in a call to torch.tensor([..]) with constant inputs.
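As a minimal illustration of where the graph constant comes from (a plain torch.fx sketch, not the capture_pre_autograd_graph path from the report): a tensor freshly created inside forward is baked into the traced module as a `_tensor_constant0` attribute, which is the value the exported graph later feeds to aten.lift_fresh_copy.

```python
import torch

class M(torch.nn.Module):
    def forward(self, scores, mask):
        # torch.tensor(...) runs eagerly during tracing; the resulting
        # tensor is recorded on the traced module as _tensor_constant0.
        fill = torch.tensor(torch.finfo(torch.float32).min)
        return scores.masked_fill(mask, fill)

gm = torch.fx.symbolic_trace(M())
print(gm.code)  # the generated code references self._tensor_constant0
```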


leslie-fang-intel commented Nov 16, 2023

Thanks @jon-chuang for the comment. @ezyang @wconstab, do we have any ideas on how to fix this issue? It would be appreciated if anyone could help look into it.

ezyang added a commit that referenced this issue Nov 17, 2023
In this case, the input could be fake!  Just treat it normally in that case.

Fixes #113331

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
ezyang added further commits referencing this issue on Nov 17 and Nov 19, 2023, with the same message.

Pull Request resolved: #113923