
irrelevant error output for Minified repro #97750

Open · allanchan339 opened this issue Mar 28, 2023 · 1 comment
Labels
module: minifier · oncall: pt2 · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


allanchan339 commented Mar 28, 2023

๐Ÿ› Describe the bug

I am trying to locate an error that appears after applying torch.compile() to my UNet. torch.compile() is scoped to self.unet only.
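For reference, this is roughly how the compilation is scoped (a minimal sketch; the wrapper class and everything other than self.unet are hypothetical, for illustration only):

import torch

class DiffusionWrapper(torch.nn.Module):  # hypothetical wrapper
    def __init__(self, unet: torch.nn.Module):
        super().__init__()
        # compile only the UNet; the rest of the model runs eagerly
        self.unet = torch.compile(unet)

    def forward(self, x, t, cond):
        return self.unet(x, t, cond)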
The error:

Exception occurred: BackendCompilerFailed
debug_wrapper raised DataDependentOutputException: aten._local_scalar_dense.default


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

The above exception was the direct cause of the following exception:
File "/code/EnlightDiff/diffusion.py", line 216, in p_losses
    noise_pred = self.unet(x_t, t, cond)
  File "/code/EnlightDiff/diffusion.py", line 223, in forward
    noise_pred, noise = self.p_losses(x_start, t, cond, center, noise=noise)
  File "/code/EnlightDiff/diffusion.py", line 489, in training_step
    noise_pred, noise = self.model(x_start, t, cond, center)
  File "/code/EnlightDiff/main.py", line 341, in main
    trainer.fit(model=litmodel, datamodule=litdataModule)
  File "/code/EnlightDiff/main.py", line 363, in <module>
    main(use_LOL4K=use_LOL4K, on_diffusion=on_diffusion, on_encoder=on_encoder,
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised DataDependentOutputException: aten._local_scalar_dense.default


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
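For context, aten._local_scalar_dense is the op that backs Tensor.item(): its output depends on actual tensor values, which fake tensors cannot provide during tracing, hence the DataDependentOutputException. A minimal sketch of the same failure mode (assuming the UNet performs a data-dependent scalar read somewhere):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    t = torch.ones(3)
    # .item() dispatches to aten._local_scalar_dense, which raises
    # DataDependentOutputException under fake tensors
    t.sum().item()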

Therefore, as a follow-up to issue #97749, I used cuda:0 to produce the minified repro again.

The exception raised seems irrelevant and unhelpful for locating the runtime error from torch.compile().

Error logs

[2023-03-28 14:21:39,824] torch._dynamo.debug_utils: [WARNING] Compiled Fx GraphModule failed. Creating script to minify the error.
[2023-03-28 14:21:39,828] torch._dynamo.debug_utils: [WARNING] Writing checkpoint with 19 nodes to /code/EnlightDiff/torch_compile_debug/run_2023_03_28_13_58_25_580581-pid_1719185/minifier/checkpoints/minified_19_nodes.py
[2023-03-28 14:21:39,829] torch._dynamo.debug_utils: [WARNING] Copying /code/EnlightDiff/torch_compile_debug/run_2023_03_28_13_58_25_580581-pid_1719185/minifier/checkpoints/minified_19_nodes.py to /code/EnlightDiff/torch_compile_debug/run_2023_03_28_13_58_25_580581-pid_1719185/minifier/repro.py for convenience
Traceback (most recent call last):
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1073, in dynamo_minifier_backend
    raise ValueError("No issue was detected")
ValueError: No issue was detected

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 666, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1088, in dynamo_minifier_backend
    minifier(
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_functorch/fx_minifier.py", line 97, in minifier
    raise RuntimeError("Input graph did not fail the tester")
RuntimeError: Input graph did not fail the tester

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/code/EnlightDiff/torch_compile_debug/run_2023_03_28_13_58_25_580581-pid_1719185/minifier/minifier_launcher.py", line 69, in <module>
    opt_mod(*args)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 541, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised RuntimeError: Input graph did not fail the tester

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
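As the log suggests, the backend failure can be silenced so the affected frames fall back to eager; separately, capture_scalar_outputs (visible in the pickled dynamo config below) is the switch that lets dynamo trace .item()-style reads. A sketch of both, assuming the root cause really is a data-dependent scalar read:

import torch._dynamo

# fall back to eager instead of raising BackendCompilerFailed
torch._dynamo.config.suppress_errors = True

# allow dynamo to capture data-dependent scalar reads (e.g. Tensor.item())
# (assumption: this addresses the original DataDependentOutputException)
torch._dynamo.config.capture_scalar_outputs = True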

Minified repro

import os
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import functools
import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.backends.registry import lookup_backend
from torch._dynamo.testing import rand_strided

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
torch._dynamo.config.load_config(b'\x80\x02}q\x00(X\x0b\x00\x00\x00output_codeq\x01\x89X\r\x00\x00\x00log_file_nameq\x02NX\x07\x00\x00\x00verboseq\x03\x88X\x11\x00\x00\x00output_graph_codeq\x04\x89X\x12\x00\x00\x00verify_correctnessq\x05\x89X\x12\x00\x00\x00minimum_call_countq\x06K\x01X\x15\x00\x00\x00dead_code_eliminationq\x07\x88X\x10\x00\x00\x00cache_size_limitq\x08K@X\x14\x00\x00\x00specialize_int_floatq\t\x88X\x0e\x00\x00\x00dynamic_shapesq\n\x89X\x10\x00\x00\x00guard_nn_modulesq\x0b\x89X\x1b\x00\x00\x00traceable_tensor_subclassesq\x0cc__builtin__\nset\nq\r]q\x0e\x85q\x0fRq\x10X\x0f\x00\x00\x00suppress_errorsq\x11\x89X\x15\x00\x00\x00replay_record_enabledq\x12\x88X \x00\x00\x00rewrite_assert_with_torch_assertq\x13\x88X\x12\x00\x00\x00print_graph_breaksq\x14\x89X\x07\x00\x00\x00disableq\x15\x89X*\x00\x00\x00allowed_functions_module_string_ignorelistq\x16h\r]q\x17(X\x0b\x00\x00\x00torch._refsq\x18X\r\x00\x00\x00torch._decompq\x19X\x13\x00\x00\x00torch.distributionsq\x1aX\r\x00\x00\x00torch.testingq\x1bX\x0c\x00\x00\x00torch._primsq\x1ce\x85q\x1dRq\x1eX\x12\x00\x00\x00repro_forward_onlyq\x1f\x89X\x0f\x00\x00\x00repro_toleranceq G?PbM\xd2\xf1\xa9\xfcX\x16\x00\x00\x00capture_scalar_outputsq!\x89X\x19\x00\x00\x00enforce_cond_guards_matchq"\x88X\x0c\x00\x00\x00optimize_ddpq#\x88X\x1a\x00\x00\x00raise_on_ctx_manager_usageq$\x88X\x1c\x00\x00\x00raise_on_unsafe_aot_autogradq%\x89X\x17\x00\x00\x00raise_on_backend_changeq&\x89X\x18\x00\x00\x00error_on_nested_fx_traceq\'\x88X\t\x00\x00\x00allow_rnnq(\x89X\x08\x00\x00\x00base_dirq)X>\x00\x00\x00/home/cychan/mambaforge/envs/ldm3/lib/python3.10/site-packagesq*X\x0e\x00\x00\x00debug_dir_rootq+X%\x00\x00\x00/code/EnlightDiff/torch_compile_debugq,X)\x00\x00\x00DO_NOT_USE_legacy_non_fake_example_inputsq-\x89X\x13\x00\x00\x00_save_config_ignoreq.h\r]q/(X\x0b\x00\x00\x00repro_afterq0X!\x00\x00\x00skipfiles_inline_module_allowlistq1X\x12\x00\x00\x00constant_functionsq2X\x0b\x00\x00\x00repro_levelq3e\x85q4Rq5u.')
torch._inductor.config.load_config(b'\x80\x02}q\x00(X\x05\x00\x00\x00debugq\x01\x89X\x10\x00\x00\x00disable_progressq\x02\x88X\x10\x00\x00\x00verbose_progressq\x03\x89X\x0b\x00\x00\x00cpp_wrapperq\x04\x89X\x03\x00\x00\x00dceq\x05\x89X\x14\x00\x00\x00static_weight_shapesq\x06\x88X\x0c\x00\x00\x00size_assertsq\x07\x88X\x10\x00\x00\x00pick_loop_ordersq\x08\x88X\x0f\x00\x00\x00inplace_buffersq\t\x88X\x11\x00\x00\x00benchmark_harnessq\n\x88X\x0f\x00\x00\x00epilogue_fusionq\x0b\x89X\x15\x00\x00\x00epilogue_fusion_firstq\x0c\x89X\x0f\x00\x00\x00pattern_matcherq\r\x88X\n\x00\x00\x00reorderingq\x0e\x89X\x0c\x00\x00\x00max_autotuneq\x0f\x89X\x17\x00\x00\x00realize_reads_thresholdq\x10K\x04X\x17\x00\x00\x00realize_bytes_thresholdq\x11M\xd0\x07X\x1b\x00\x00\x00realize_acc_reads_thresholdq\x12K\x08X\x0f\x00\x00\x00fallback_randomq\x13\x89X\x12\x00\x00\x00implicit_fallbacksq\x14\x88X\x0b\x00\x00\x00tune_layoutq\x15\x89X\x11\x00\x00\x00aggressive_fusionq\x16\x89X\x0f\x00\x00\x00max_fusion_sizeq\x17K@X\x1b\x00\x00\x00unroll_reductions_thresholdq\x18K\x08X\x0e\x00\x00\x00comment_originq\x19\x89X\x12\x00\x00\x00developer_warningsq\x1a\x89X\x0f\x00\x00\x00compile_threadsq\x1bK X\x13\x00\x00\x00kernel_name_max_opsq\x1cK\nX\r\x00\x00\x00shape_paddingq\x1d\x89X\x0e\x00\x00\x00permute_fusionq\x1e\x89X\x1a\x00\x00\x00profiler_mark_wrapper_callq\x1f\x89X\x18\x00\x00\x00_raise_error_for_testingq \x89X\x0b\x00\x00\x00cpp.threadsq!J\xff\xff\xff\xffX\x13\x00\x00\x00cpp.dynamic_threadsq"\x89X\x0b\x00\x00\x00cpp.simdlenq#NX\x12\x00\x00\x00cpp.min_chunk_sizeq$M\x00\x10X\x07\x00\x00\x00cpp.cxxq%NX\x03\x00\x00\x00g++q&\x86q\'X\x19\x00\x00\x00cpp.enable_kernel_profileq(\x89X\x12\x00\x00\x00cpp.weight_prepackq)\x88X\x11\x00\x00\x00triton.cudagraphsq*\x89X\x17\x00\x00\x00triton.debug_sync_graphq+\x89X\x18\x00\x00\x00triton.debug_sync_kernelq,\x89X\x15\x00\x00\x00triton.dense_indexingq-\x89X\x10\x00\x00\x00triton.max_tilesq.K\x02X\x19\x00\x00\x00triton.autotune_pointwiseq/\x88X\'\x00\x00\x00triton.tiling_prevents_pointwise_fusionq0\x88X\'\x00\x00\x00triton.tiling_prevents_reduction_fusionq1\x88X\x1b\x00\x00\x00triton.ordered_kernel_namesq2\x89X\x1f\x00\x00\x00triton.descriptive_kernel_namesq3\x89X\x1c\x00\x00\x00triton.persistent_reductionsq4\x89X\r\x00\x00\x00trace.enabledq5\x89X\x0f\x00\x00\x00trace.debug_logq6\x88X\x0e\x00\x00\x00trace.info_logq7\x89X\x0e\x00\x00\x00trace.fx_graphq8\x88X\x1a\x00\x00\x00trace.fx_graph_transformedq9\x88X\x13\x00\x00\x00trace.ir_pre_fusionq:\x88X\x14\x00\x00\x00trace.ir_post_fusionq;\x88X\x11\x00\x00\x00trace.output_codeq<\x88X\x13\x00\x00\x00trace.graph_diagramq=\x89X\x15\x00\x00\x00trace.compile_profileq>\x89X\x10\x00\x00\x00trace.upload_tarq?Nu.')
torch._functorch.config.load_config(b'\x80\x02}q\x00(X\x11\x00\x00\x00use_functionalizeq\x01\x88X\x0f\x00\x00\x00use_fake_tensorq\x02\x88X\x16\x00\x00\x00fake_tensor_allow_metaq\x03\x88X\x0c\x00\x00\x00debug_assertq\x04\x88X\x14\x00\x00\x00debug_fake_cross_refq\x05\x89X\x11\x00\x00\x00debug_partitionerq\x06\x89X\x0c\x00\x00\x00debug_graphsq\x07\x89X\x0b\x00\x00\x00debug_jointq\x08\x89X\x12\x00\x00\x00use_dynamic_shapesq\t\x89X\x14\x00\x00\x00static_weight_shapesq\n\x88X\x03\x00\x00\x00cseq\x0b\x88X\x10\x00\x00\x00max_dist_from_bwq\x0cK\x03X\t\x00\x00\x00log_levelq\rK\x14u.')


# REPLACEABLE COMMENT FOR TESTING PURPOSES


args = [((32, 3, 160, 160), (76800, 25600, 160, 1), torch.float32, 'cuda', False), ((32,), (1,), torch.int64, 'cuda', False), ((32, 3, 160, 160), (76800, 25600, 160, 1), torch.float16, 'cuda', False)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_mlp_0 = Linear(in_features=64, out_features=256, bias=True).cuda()
        self.self_mlp_2 = Linear(in_features=256, out_features=64, bias=True).cuda()
    def forward(self, x : torch.Tensor, time : torch.Tensor, cond : torch.Tensor):
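        # sinusoidal timestep embedding: exp-scaled arange frequencies,
        # multiplied by the timestep, then sin/cos concatenated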
        arange = torch.arange(32, device = device(type='cuda', index=0))
        mul = arange * -0.2971077539347156;  arange = None
        exp = torch.exp(mul);  mul = None
        getitem = time[(slice(None, None, None), None)];  time = None
        getitem_1 = exp[(None, slice(None, None, None))];  exp = None
        mul_1 = getitem * getitem_1;  getitem = getitem_1 = None
        sin = mul_1.sin()
        cos = mul_1.cos();  mul_1 = None
        cat = torch.cat((sin, cos), dim = -1);  sin = cos = None
        self_mlp_0 = self.self_mlp_0(cat);  cat = None
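        # Mish activation: self_mlp_0 * tanh(softplus(self_mlp_0))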
        softplus = torch._C._nn.softplus(self_mlp_0)
        tanh = torch.tanh(softplus);  softplus = None
        mul_2 = self_mlp_0 * tanh;  self_mlp_0 = tanh = None
        self_mlp_2 = self.self_mlp_2(mul_2);  mul_2 = None
        cat_1 = torch.cat((x, cond), dim = 1);  x = cond = None
        return (cat_1, self_mlp_2)


mod = Repro()

# Setup debug minifier compiler
torch._dynamo.debug_utils.MINIFIER_SPAWNED = True
compiler_fn = lookup_backend("dynamo_minifier_backend")

dynamo_minifier_backend = functools.partial(
    compiler_fn,
    compiler_name="inductor",
)
opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod)

with torch.cuda.amp.autocast(enabled=True):
    opt_mod(*args)
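A note on why this repro reports an unrelated error: judging from the logs above, the dynamo_minifier_backend re-runs the captured graph and only bisects it if it still fails; here the graph succeeds, so the minifier raises "Input graph did not fail the tester" instead of reproducing the original DataDependentOutputException. A quick sanity check (a sketch, reusing the mod and args defined above):

# if this eager run succeeds, the minifier has nothing to bisect and
# will raise "Input graph did not fail the tester"
with torch.cuda.amp.autocast(enabled=True):
    mod(*args)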

Versions

Please refer to #97749.

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh

@yanboliang added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Mar 30, 2023
yanboliang (Contributor) commented:
cc @anijain2305 @mlazos
