Closed
Labels: empathy-day, module: aotdispatch, module: pt2-dispatcher, oncall: pt2, triaged
Description
🐛 Describe the bug
Discovered while compiling lucidrains/audiolm-pytorch with `torch.compile` (Inductor backend):
/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/autograd/__init__.py:411: UserWarning: Error detected in BmmBackward0. Traceback of forward call that caused the error:
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/vector_quantize_pytorch.py", line 470, in resume_in_forward
dist = -cdist(flatten, embed)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/vector_quantize_pytorch.py", line 32, in cdist
xy = einsum('b i d, b j d -> b i j', x, y) * -2
(Triggered internally at /opt/conda/conda-bld/pytorch_1708025847130/work/torch/csrc/autograd/python_anomaly_mode.cpp:113.)
result = Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
training with dataset of 2 samples and validating with randomly splitted 1 samples
Traceback (most recent call last):
File "/data/users/ezyang/audiolm-pytorch/demo.py", line 139, in <module>
trainer.train()
File "/data/users/ezyang/audiolm-pytorch/audiolm_pytorch/trainer.py", line 707, in train
logs = self.train_step()
File "/data/users/ezyang/audiolm-pytorch/audiolm_pytorch/trainer.py", line 576, in train_step
loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1509, in _wrapped_call_impl
return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/ezyang/audiolm-pytorch/audiolm_pytorch/soundstream.py", line 830, in forward
x = self.encoder_attn(x)
File "/data/users/ezyang/audiolm-pytorch/audiolm_pytorch/soundstream.py", line 837, in resume_in_forward
x, indices, commit_loss = self.rq(x)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 296, in forward
out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices))
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 296, in resume_in_forward
out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices))
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 296, in <genexpr>
out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices))
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 153, in forward
rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 155, in resume_in_forward
rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 177, in resume_in_forward
quantized, *rest = layer(
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/vector_quantize_pytorch.py", line 901, in forward
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/vector_quantize_pytorch.py", line 458, in forward
self.init_embed_(flatten, mask = mask)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
compiled_product = _compile(
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
return fn(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
tracer.run()
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
super().run()
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 465, in wrapper
return handle_graph_break(self, inst, speculation.reason)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 521, in handle_graph_break
self.output.compile_subgraph(self, reason=reason)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 945, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1087, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1159, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1140, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/__init__.py", line 1668, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1168, in compile_fx
return aot_autograd(
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 887, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 600, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 425, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 630, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 151, in aot_dispatch_autograd
fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph( # type: ignore[misc]
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 159, in aot_dispatch_autograd_graph
fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 32, in _create_graph
fx_g = make_fx(f, decomposition_table=aot_config.decompositions)(*args)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 871, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 821, in trace
(self.create_arg(fn(*args)),),
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 688, in flatten_fn
tree_out = root_fn(*tree_args)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 519, in wrapped
out = f(*tensors)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 463, in joint_helper
return _functionalized_f_helper(primals, tangents)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 355, in _functionalized_f_helper
f_outs = fn(*f_args)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 250, in inner_fn_with_anomaly
return inner_fn(*args)
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 235, in inner_fn
backward_out = torch.autograd.grad(
File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/autograd/__init__.py", line 411, in grad
result = Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 1024]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
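The final `RuntimeError` comes from autograd's saved-tensor version check: `bmm` saves its inputs for the backward pass, and if one of them (or a view sharing its version counter) is mutated in place before `backward`/`autograd.grad` runs, the counter no longer matches. In this report the check fires while AOTAutograd traces the joint graph, not in eager mode. The minimal eager-mode sketch below shows the same class of failure. It assumes, without the report confirming it, that an input saved by the `einsum` in `cdist` is modified in place later. The einsum is written directly as `torch.bmm`, since the error names `BmmBackward0`. `flatten` and `embed` are borrowed from the trace, while the shapes and the `run` helper are hypothetical.

```python
import torch


def cdist_xy(x, y):
    # Equivalent of einsum('b i d, b j d -> b i j', x, y) * -2 from the trace,
    # written as bmm; bmm saves a view of y to compute the gradient w.r.t. x.
    return torch.bmm(x, y.transpose(1, 2)) * -2


def run(inplace):
    torch.manual_seed(0)
    flatten = torch.randn(1, 4, 8, requires_grad=True)
    embed = torch.randn(1, 3, 8)  # hypothetical stand-in for a codebook buffer
    xy = cdist_xy(flatten, embed)
    if inplace:
        # Mutates the tensor bmm saved (views share one version counter),
        # bumping its version from 0 to 1, so backward will fail the check.
        embed.mul_(0.5)
    else:
        # Out-of-place update: rebinds the name, saved tensor stays at version 0.
        embed = embed * 0.5
    xy.sum().backward()
    return flatten.grad
```

Calling `run(inplace=True)` raises the same "modified by an inplace operation" error. `run(inplace=False)` completes. This only illustrates the mechanism and is not a claim about where the mutation happens in vector_quantize_pytorch.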
Here's the Dynamo output graph that we try to run through AOTAutograd: https://gist.github.com/ezyang/c7e157dc9f8314794e90e4c314eb3c21
Full repro code: gist.github.com/ezyang/64c24c9fc5529f3afed4ee4266f6adc5
Versions
run