Closed
Labels
oncall: cpu inductor (CPU Inductor issues for Intel team to triage)
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
maml_omniglot crashes in both the accuracy and performance runs with AMP, dynamic shapes, and the default wrapper.
suite | name | thread | accuracy | perf | reason (reference only) |
---|---|---|---|---|---|
torchbench | maml_omniglot | multiple | X | N/A | maml_omniglot, AssertionError: expected size 64==64 stride 1==676 at dim=1 |
torchbench | maml_omniglot | single | X | N/A | maml_omniglot, AssertionError: expected size 64==64 stride 1==676 at dim=1 |
loading model: 0it [00:00, ?it/s]
cpu eval maml_omniglot
skipping cudagraphs due to skipping cudagraphs due to cpu device. Found from :
  File "benchmarks/dynamo/torchbench.py", line 357, in forward_pass
    return mod(*inputs)
ERROR:common:Backend dynamo failed in warmup()
Traceback (most recent call last):
  File "/workspace/pytorch/benchmarks/dynamo/common.py", line 2639, in warmup
    fn(model, example_inputs)
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 390, in _fn
    return fn(*args, **kwargs)
  File "benchmarks/dynamo/torchbench.py", line 355, in forward_pass
    def forward_pass(self, mod, inputs, collect_outputs=True):
  File "/workspace/pytorch/torch/_dynamo/eval_frame.py", line 390, in _fn
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/workspace/pytorch/torch/_functorch/aot_autograd.py", line 917, in forward
    return compiled_fn(full_args)
  File "/workspace/pytorch/torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "/workspace/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 106, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/workspace/pytorch/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/workspace/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 152, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/workspace/pytorch/torch/_inductor/compile_fx.py", line 1093, in wrapper
    return optimized_function(args_new)
  File "/workspace/pytorch/torch/_inductor/codecache.py", line 937, in __call__
    return self.current_callable(inputs)
  File "/tmp/tmp2dbylmum/b6/cb6xejcvhq24bivlii5l4bnrokyshg25eaegfundva3emkz5hllr.py", line 251, in call
    assert_size_stride(buf1, (s0, 64, 26, 26), (43264, 676, 26, 1))
AssertionError: expected size 64==64, stride 1==676 at dim=1
Run failed with return code: 255
Output: None
Error: None
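For readers unfamiliar with the failing check: the generated wrapper expects `buf1` to have the contiguous NCHW strides `(43264, 676, 26, 1)`, but the buffer it received has stride 1 at dim 1. The plain-Python sketch below reproduces that arithmetic. It assumes a hypothetical batch size of 5 in place of the symbolic `s0`, and `check_size_stride` is a hypothetical stand-in that only mimics the message format of the generated `assert_size_stride` call. A channels-last (NHWC) layout is one layout that yields stride 1 at dim 1; whether that is the actual cause here is an assumption, not something the log confirms.

```python
from typing import Sequence, Tuple


def contiguous_strides(size: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (NCHW-contiguous) strides, in elements, for the given size."""
    strides = []
    running = 1
    for dim in reversed(size):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def channels_last_strides(size: Sequence[int]) -> Tuple[int, ...]:
    """Strides of an NCHW-shaped 4-D tensor whose memory is laid out as NHWC."""
    n, c, h, w = size
    return (h * w * c, 1, w * c, c)


def check_size_stride(
    actual_size: Sequence[int],
    actual_stride: Sequence[int],
    expected_size: Sequence[int],
    expected_stride: Sequence[int],
) -> None:
    """Hypothetical stand-in for the generated check (not the PyTorch implementation).

    Assumption in this sketch: strides are only compared for dims with more
    than one element, since the stride of a size-1 dim is irrelevant.
    """
    rows = zip(actual_size, actual_stride, expected_size, expected_stride)
    for dim, (a_size, a_stride, e_size, e_stride) in enumerate(rows):
        if a_size != e_size or (a_size > 1 and a_stride != e_stride):
            raise AssertionError(
                f"expected size {a_size}=={e_size}, "
                f"stride {a_stride}=={e_stride} at dim={dim}"
            )


if __name__ == "__main__":
    size = (5, 64, 26, 26)  # 5 is a hypothetical value for the symbolic s0
    expected = contiguous_strides(size)    # (43264, 676, 26, 1)
    actual = channels_last_strides(size)   # (43264, 1, 1664, 64)
    try:
        check_size_stride(size, actual, size, expected)
    except AssertionError as err:
        print(err)  # expected size 64==64, stride 1==676 at dim=1
```

Dim 0 passes in both layouts because `64 * 26 * 26 = 43264` either way, which is why the first mismatch is reported at dim 1.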
Versions
SW info
name | target_branch | target_commit | refer_branch | refer_commit |
---|---|---|---|---|
torchbench | main | d6015d42 | main | 1ef0a39e |
torch | main | 5bc7f7f | main | 41286f1 |
torchvision | main | 0.18.0a0+0325175 | main | 0.18.0a0+2c127da |
torchtext | main | 0.16.0a0+b0ebddc | main | 0.16.0a0+b0ebddc |
torchaudio | main | 2.2.0a0+87aeb55 | main | 2.2.0a0+87aeb55 |
torchdata | main | 0.7.1a0+0790338 | main | 0.7.1a0+0790338 |
dynamo_benchmarks | main | nightly | main | nightly |
Repro:
inductor_single_run.sh
bash inductor_single_run.sh thread inference accuracy/performance torchbench maml_omniglot AMP first dynamic default 0
Suspected guilty commit: f2f8eee
torchbench-maml_omniglot-inference-amp-dynamic-default-single-accuracy-crash_guilty_commit.log
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @WeizhuoZhang-intel @chuanqi129