[dynamo] Error "Inference tensors do not track version counter" in inference_mode w/ llama7b #101151

Closed
jgong5 opened this issue May 11, 2023 · 7 comments

jgong5 commented May 11, 2023

🐛 Describe the bug

Transformers: 4.29.0

Repro:

from transformers import LlamaForCausalLM, LlamaTokenizer
import torch
from itertools import chain

# load model
print("Loading model...")
model_id = "decapoda-research/llama-7b-hf"
model = LlamaForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float32
)
tokenizer = LlamaTokenizer.from_pretrained(model_id)
print("Model loaded")
model = model.eval()

import torch._inductor.config as config
torch._dynamo.config.assume_static_by_default = False
model.generate = torch.compile(model.generate, dynamic=True)

print("Model initialized")

# input prompt
prompt = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"

input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)

generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
with torch.inference_mode(): # no problem without this
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output = model.generate(
        input_ids, max_new_tokens=32, **generate_kwargs
    )
    gen_ids = output
    gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

Error log:

[2023-05-11 10:43:51,807] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
Traceback (most recent call last):
  File "run_llama_compile.py", line 30, in <module>
    output = model.generate(
  File "/home/jgong5/pytorch/torch/_dynamo/eval_frame.py", line 286, in _fn
    return fn(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/generation/utils.py", line 1246, in generate
    self._validate_model_class()
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/generation/utils.py", line 1253, in <resume in generate>
    new_generation_config = GenerationConfig.from_model_config(self.config)
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/generation/utils.py", line 1604, in <resume in generate>
    return self.beam_search(
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/generation/utils.py", line 2837, in beam_search
    if len(stopping_criteria) == 0:
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/generation/utils.py", line 2902, in <resume in beam_search>
    outputs = self(
  File "/home/jgong5/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 688, in forward
    outputs = self.model(
  File "/home/jgong5/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 537, in forward
    attention_mask = self._prepare_decoder_attention_mask(
  File "/home/jgong5/pytorch/torch/_dynamo/eval_frame.py", line 439, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/home/jgong5/pytorch/torch/_dynamo/convert_frame.py", line 523, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/jgong5/pytorch/torch/_dynamo/convert_frame.py", line 125, in _fn
    return fn(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/_dynamo/convert_frame.py", line 358, in _convert_frame_assert
    return _compile(
  File "/home/jgong5/pytorch/torch/_dynamo/utils.py", line 177, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/_dynamo/convert_frame.py", line 428, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/jgong5/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/home/jgong5/pytorch/torch/_dynamo/convert_frame.py", line 413, in transform
    tracer.run()
  File "/home/jgong5/pytorch/torch/_dynamo/symbolic_convert.py", line 2010, in run
    super().run()
  File "/home/jgong5/pytorch/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/home/jgong5/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/home/jgong5/pytorch/torch/_dynamo/symbolic_convert.py", line 2098, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/jgong5/pytorch/torch/_dynamo/output_graph.py", line 736, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/jgong5/pytorch/torch/_dynamo/output_graph.py", line 813, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/jgong5/pytorch/torch/_dynamo/utils.py", line 177, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/_dynamo/output_graph.py", line 872, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/jgong5/pytorch/torch/_dynamo/output_graph.py", line 868, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/jgong5/pytorch/torch/_dynamo/repro/after_dynamo.py", line 108, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/jgong5/pytorch/torch/__init__.py", line 1531, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/jgong5/pytorch/torch/_inductor/compile_fx.py", line 590, in compile_fx
    return compile_fx(
  File "/home/jgong5/pytorch/torch/_inductor/compile_fx.py", line 700, in compile_fx
    return aot_autograd(
  File "/home/jgong5/pytorch/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/jgong5/pytorch/torch/_functorch/aot_autograd.py", line 3334, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/jgong5/pytorch/torch/_dynamo/utils.py", line 177, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/_functorch/aot_autograd.py", line 2959, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/jgong5/pytorch/torch/_functorch/aot_autograd.py", line 719, in inner
    flat_f_outs = f(*flat_f_args)
  File "/home/jgong5/pytorch/torch/_functorch/aot_autograd.py", line 3259, in functional_call
    out = Interpreter(mod).run(*args[params_len:], **kwargs)
  File "/home/jgong5/pytorch/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/home/jgong5/pytorch/torch/fx/interpreter.py", line 198, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/jgong5/pytorch/torch/fx/interpreter.py", line 315, in call_module
    return submod(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jgong5/pytorch/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/jgong5/pytorch/torch/_inductor/overrides.py", line 22, in __torch_function__
    return replace_fn(func, is_cpu_device(args))(*args, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Inference tensors do not track version counter.

While executing (module type: <class 'torch.nn.modules.linear.Linear'>) %l__self___layers_0_self_attn_q_proj : [#users=1] = call_module[target=L__self___layers_0_self_attn_q_proj](args = (%mul_1,), kwargs = {})
Original traceback:
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 578, in <resume in forward>
    layer_outputs = decoder_layer(
  File "/home/jgong5/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 293, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/jgong5/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jgong5/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 197, in forward
    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)



You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Versions

PyTorch version: 2.1.0a0+gitb536c40
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.2
Libc version: glibc-2.31

Python version: 3.8.16 (default, Mar 2 2023, 03:21:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-71-generic-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz
Stepping: 6
CPU MHz: 2600.000
CPU max MHz: 3400.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
L1d cache: 3 MiB
L1i cache: 2 MiB
L2 cache: 80 MiB
L3 cache: 96 MiB
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] intel-extension-for-pytorch==2.1.0+git3642f0c
[pip3] mypy==0.960
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.1
[pip3] torch==2.1.0a0+gitd1f0c8e
[conda] intel-extension-for-pytorch 2.1.0+git3642f0c dev_0
[conda] mkl 2023.0.0 pypi_0 pypi
[conda] mkl-devel 2023.0.0 pypi_0 pypi
[conda] mkl-include 2023.0.0 pypi_0 pypi
[conda] mkl-static 2023.0.0 pypi_0 pypi
[conda] numpy 1.23.1 pypi_0 pypi
[conda] torch 2.1.0a0+gitd1f0c8e dev_0

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @anijain2305


bdhirsh commented May 11, 2023

Inference mode is known to be pretty broken with torch.compile today, and we should fix it.

The easiest thing to do is to make inference mode a no-op inside of torch.compile. inference_mode is supposed to make your hot-path code cheaper by not storing autograd metadata at runtime, but that is a cost we should always be able to avoid when using torch.compile anyway (since we trace autograd ahead of time). I think inference_mode today just breaks during compilation.
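As a rough illustration of the "autograd metadata" mentioned above, here is a minimal eager-mode sketch (no torch.compile involved). It assumes a recent PyTorch build; the helper name `version_or_error` is hypothetical and only exists for this example. It shows that a tensor allocated under `torch.inference_mode()` is flagged as an inference tensor and has no version counter to read.

```python
import torch

# Tensor created in normal mode: it carries a version counter for autograd.
normal = torch.ones(3)
normal.add_(1)  # an in-place op bumps the version counter

# Tensor allocated under inference_mode: an "inference tensor" that
# skips version-counter tracking entirely.
with torch.inference_mode():
    inference = torch.ones(3)


def version_or_error(t):
    """Hypothetical helper: return t._version, or the error text if reading it fails."""
    try:
        return t._version
    except RuntimeError as e:
        return str(e)


print(normal.is_inference(), version_or_error(normal))
print(inference.is_inference(), version_or_error(inference))
```

Reading `_version` on the second tensor is expected to fail with the same "Inference tensors do not track version counter" message that appears in the issue title, which suggests some step of compilation touches the version counter of an inference tensor.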


bdhirsh commented May 11, 2023

Hmm, I made a local repro that I was able to fix with this PR: #101219

@jgong5 can you try running your repro on top of that PR and see if it fixes the issue?

bdhirsh added a commit that referenced this issue May 12, 2023
It looks like inference_mode wasn't playing well with functionalization.

If you run torch.compile on a function, and the inputs to the function are tensors created outside of inference mode, then we need to make sure that when we create functional tensor wrappers for those inputs during compilation, those functional wrappers properly mirror whether or not the original tensor is an inference tensor.

Hopefully fixes #101151
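To make the "created outside of inference mode" distinction concrete, the following eager-mode sketch (an illustration only, not the code from the PR) records the inference-tensor flag for a weight allocated outside inference mode, a view of it taken inside, and tensors allocated inside. The variable names are made up for the example.

```python
import torch

# Stand-in for model parameters: allocated in normal mode.
weight = torch.randn(4, 4)

with torch.inference_mode():
    x = torch.randn(2, 4)          # allocated inside inference mode
    y = x @ weight                 # new tensor produced inside inference mode
    weight_view = weight.view(-1)  # view of a normal tensor

# Record which tensors are flagged as inference tensors.
flags = {
    "weight": weight.is_inference(),
    "weight_view": weight_view.is_inference(),
    "x": x.is_inference(),
    "y": y.is_inference(),
}
print(flags)
```

Only the tensors allocated inside the context are inference tensors; the weight and its view stay normal tensors. This is the per-tensor property that, according to the commit message, the functional wrappers created during compilation have to mirror.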




pytorchmergebot pushed a commit that referenced this issue May 24, 2023
It looks like inference_mode wasn't playing well with functionalization.

If you run torch.compile on a function, and the inputs to the function are tensors created outside of inference mode, then we need to make sure that when we create functional tensor wrappers for those inputs during compilation, those functional wrappers properly mirror whether or not the original tensor is an inference tensor.

Hopefully fixes #101151

Pull Request resolved: #101219
Approved by: https://github.com/albanD, https://github.com/ezyang

jgong5 commented Jun 7, 2023

> Hmm, I made a local repro that I was able to fix with this PR: #101219
>
> @jgong5 can you try running your repro on top of that PR and see if it fixes the issue?

@bdhirsh sorry, just saw the message. Someone else reported a similar issue: #103132. Will check if it is a duplicate too.

ZhaoqiongZ commented

Hi @bdhirsh, I compiled PyTorch from your PR branch at commit a469373. Here is the output:
`Loading model...
Loading checkpoint shards: 100%|██████████| 33/33 [00:15<00:00, 2.19it/s]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'.
The class this function is called from is 'LlamaTokenizer'.
Model loaded
Model initialized
---- Prompt size: 32
/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py:1259: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)
warnings.warn(
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
Traceback (most recent call last):
File "/home/zhaoqion/bug_reproduce/pytorch_issue/101151.py", line 30, in <module>
output = model.generate(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 287, in _fn
return fn(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py", line 1250, in generate
self._validate_model_class()
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py", line 1257, in <resume in generate>
new_generation_config = GenerationConfig.from_model_config(self.config)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py", line 1611, in <resume in generate>
return self.beam_search(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py", line 2844, in beam_search
if len(stopping_criteria) == 0:
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py", line 2909, in <resume in beam_search>
outputs = self(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 688, in forward
outputs = self.model(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 537, in forward
attention_mask = self._prepare_decoder_attention_mask(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 440, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 527, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 127, in _fn
return fn(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
return _compile(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 430, in _compile
out_code = transform_code_object(code, transform)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 415, in transform
tracer.run()
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2019, in run
super().run()
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in run
and self.step()
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 663, in step
getattr(self, inst.opname)(inst)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2107, in RETURN_VALUE
self.output.compile_subgraph(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 755, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 832, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 891, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 887, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/__init__.py", line 1543, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 610, in compile_fx
return compile_fx(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 720, in compile_fx
return aot_autograd(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 3700, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 3176, in create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 722, in inner
flat_f_outs = f(*flat_f_args)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 3309, in functional_call
out = Interpreter(mod).run(*args[params_len:], **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/fx/interpreter.py", line 138, in run
self.env[node] = self.run_node(node)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/fx/interpreter.py", line 195, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/fx/interpreter.py", line 312, in call_module
return submod(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides

While executing %l__self___layers_0_self_attn_q_proj : [#users=1] = call_module[target=L__self___layers_0_self_attn_q_proj](args = (%to_1,), kwargs = {})
Original traceback:
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 578, in
layer_outputs = decoder_layer(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 194, in forward
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True

I also tried to reproduce the similar issue #103132.
Here is the output:
Downloading model.safetensors: 100%|██████████| 990M/990M [01:55<00:00, 8.59MB/s]
/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py:1353: UserWarning: Using `max_length`'s default (20) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.
warnings.warn(
eager eval time 0: 0.6019136905670166
eager eval time 1: 0.45739245414733887
eager eval time 2: 0.41426658630371094
eager eval time 3: 0.4186084270477295
eager eval time 4: 0.4126737117767334
eager eval time 5: 0.41367483139038086
eager eval time 6: 0.40696215629577637
eager eval time 7: 0.4067347049713135
eager eval time 8: 0.4078559875488281
eager eval time 9: 0.4479329586029053
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py:1353: UserWarning: Using `max_length`'s default (20) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.
warnings.warn(
Traceback (most recent call last):
File "/home/zhaoqion/bug_reproduce/pytorch_issue/103132.py", line 23, in
output_compile = model.generate(**inputs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 287, in _fn
return fn(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py", line 1250, in generate
self._validate_model_class()
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py", line 1257, in
new_generation_config = GenerationConfig.from_model_config(self.config)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py", line 1522, in
return self.greedy_search(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/generation/utils.py", line 2339, in greedy_search
outputs = self(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 440, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 527, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 127, in _fn
return fn(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
return _compile(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 430, in _compile
out_code = transform_code_object(code, transform)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 415, in transform
tracer.run()
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2019, in run
super().run()
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in run
and self.step()
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 663, in step
getattr(self, inst.opname)(inst)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2107, in RETURN_VALUE
self.output.compile_subgraph(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 755, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 832, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 891, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 887, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/__init__.py", line 1543, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 610, in compile_fx
return compile_fx(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 720, in compile_fx
return aot_autograd(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 3700, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 3176, in create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 722, in inner
flat_f_outs = f(*flat_f_args)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 3309, in functional_call
out = Interpreter(mod).run(*args[params_len:], **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/fx/interpreter.py", line 138, in run
self.env[node] = self.run_node(node)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/fx/interpreter.py", line 195, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/fx/interpreter.py", line 267, in call_function
return target(*args, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides

While executing %matmul : [#users=1] = call_function[target=torch.matmul](args = (%transpose, %transpose_3), kwargs = {})
Original traceback:
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1720, in forward
decoder_outputs = self.decoder(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1090, in forward
layer_outputs = layer_module(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 693, in forward
self_attention_outputs = self.layer[0](
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 600, in forward
attention_output = self.SelfAttention(
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zhaoqion/miniconda3/envs/103132/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 530, in forward
scores = torch.matmul(

You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
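
The fallback suggested at the end of the log can be tried in isolation. The following is only a minimal sketch: `failing_backend` and `double` are hypothetical names, and the backend simply raises to stand in for the inductor failure above. It assumes a PyTorch build where `torch.compile` accepts a callable backend.

```python
import torch
import torch._dynamo


def failing_backend(gm, example_inputs):
    # Hypothetical backend that always fails, standing in for the
    # inductor failure shown in the traceback above.
    raise RuntimeError("simulated backend failure")


def double(x):
    return x * 2


# With suppress_errors enabled, the backend failure is logged as a warning
# and the function runs in eager mode instead of raising.
torch._dynamo.config.suppress_errors = True
compiled_double = torch.compile(double, backend=failing_backend)

result = compiled_double(torch.ones(3))
print(result)
```

Note that this hides the compile failure rather than fixing it, so the affected frames get no speedup from compilation.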

@bdhirsh
Contributor

bdhirsh commented Jun 8, 2023

Thanks for the callout @ZhaoqiongZ - I have another fix that should resolve this one (also linked in the new issue): #103275

@cs-mshah

I recently hit this issue in inference_mode with the error:

torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised RuntimeError: Inference tensors do not track version counter.
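
For context on the message itself, here is a small eager-mode sketch (no `torch.compile` involved) showing how tensors created under `torch.inference_mode()` differ from ordinary ones. The variable names are illustrative, and this only demonstrates where the wording of the error comes from; it is not a reproduction of the compile failure.

```python
import torch

with torch.inference_mode():
    inference_t = torch.ones(2)

with torch.no_grad():
    no_grad_t = torch.ones(2)

# is_inference() reports whether a tensor was created under inference_mode.
print(inference_t.is_inference(), no_grad_t.is_inference())

# Ordinary tensors carry a version counter that in-place ops bump.
no_grad_t.add_(1)
print(no_grad_t._version)

# Inference tensors have no version counter, so reading it raises.
try:
    inference_t._version
    version_error = None
except RuntimeError as err:
    version_error = str(err)
print(version_error)
```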

123epsilon added a commit to nod-ai/SHARK-Turbine that referenced this issue Aug 22, 2023
Add a test case for passing llama through the `turbine_cpu` backend. This
replaces all fairscale layers with corresponding vanilla torch layers
for simplicity, but we can add these back later once we have llama
working. Also removes the `@torch.inference_mode()` decorator to avoid
the issue documented
[here](pytorch/pytorch#101151), which is not
necessarily relevant to the quality of our pipeline.
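
The commit message above only says the `@torch.inference_mode()` decorator was removed. One possible variant, sketched here as an assumption rather than as what that commit did, is to use `@torch.no_grad()` instead, which still disables gradient tracking but produces ordinary tensors. `linear` and `run_model` are hypothetical stand-ins for the real model and entry point, and this sketch does not verify that the compile error goes away.

```python
import torch

linear = torch.nn.Linear(4, 2)  # hypothetical stand-in for the real model


@torch.no_grad()  # assumption: used in place of @torch.inference_mode()
def run_model(x):
    return linear(x)


out = run_model(torch.randn(3, 4))

# The output carries no autograd history, yet it is an ordinary
# (non-inference) tensor that still tracks a version counter.
print(out.requires_grad, out.is_inference())
```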
@ycattan

ycattan commented Jan 15, 2024

@cs-mshah did you find a working solution?
