
[AudioLM] AOTAutograd: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation #121353

@ezyang


🐛 Describe the bug

Discovered while compiling lucidrains/audiolm-pytorch

/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/autograd/__init__.py:411: UserWarning: Error detected in BmmBackward0. Traceback of forward call that caused the error:
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/vector_quantize_pytorch.py", line 470, in resume_in_forward
    dist = -cdist(flatten, embed)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/vector_quantize_pytorch.py", line 32, in cdist
    xy = einsum('b i d, b j d -> b i j', x, y) * -2
 (Triggered internally at /opt/conda/conda-bld/pytorch_1708025847130/work/torch/csrc/autograd/python_anomaly_mode.cpp:113.)
  result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
training with dataset of 2 samples and validating with randomly splitted 1 samples
Traceback (most recent call last):
  File "/data/users/ezyang/audiolm-pytorch/demo.py", line 139, in <module>
    trainer.train()
  File "/data/users/ezyang/audiolm-pytorch/audiolm_pytorch/trainer.py", line 707, in train
    logs = self.train_step()
  File "/data/users/ezyang/audiolm-pytorch/audiolm_pytorch/trainer.py", line 576, in train_step
    loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1509, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/ezyang/audiolm-pytorch/audiolm_pytorch/soundstream.py", line 830, in forward
    x = self.encoder_attn(x)
  File "/data/users/ezyang/audiolm-pytorch/audiolm_pytorch/soundstream.py", line 837, in resume_in_forward
    x, indices, commit_loss = self.rq(x)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 296, in forward
    out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices))
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 296, in resume_in_forward
    out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices))
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 296, in <genexpr>
    out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices))
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 153, in forward
    rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 155, in resume_in_forward
    rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/residual_vq.py", line 177, in resume_in_forward
    quantized, *rest = layer(
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/vector_quantize_pytorch.py", line 901, in forward
    quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/vector_quantize_pytorch-1.14.1-py3.10.egg/vector_quantize_pytorch/vector_quantize_pytorch.py", line 458, in forward
    self.init_embed_(flatten, mask = mask)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 465, in wrapper
    return handle_graph_break(self, inst, speculation.reason)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 521, in handle_graph_break
    self.output.compile_subgraph(self, reason=reason)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 945, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1087, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1159, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1140, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/__init__.py", line 1668, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1168, in compile_fx
    return aot_autograd(
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 887, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 600, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 425, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 630, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 151, in aot_dispatch_autograd
    fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(  # type: ignore[misc]
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 159, in aot_dispatch_autograd_graph
    fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 32, in _create_graph
    fx_g = make_fx(f, decomposition_table=aot_config.decompositions)(*args)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 871, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 821, in trace
    (self.create_arg(fn(*args)),),
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 688, in flatten_fn
    tree_out = root_fn(*tree_args)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 519, in wrapped
    out = f(*tensors)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 463, in joint_helper
    return _functionalized_f_helper(primals, tangents)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 355, in _functionalized_f_helper
    f_outs = fn(*f_args)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 250, in inner_fn_with_anomaly
    return inner_fn(*args)
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 235, in inner_fn
    backward_out = torch.autograd.grad(
  File "/home/ezyang/local/miniconda3-test/envs/audiolm/lib/python3.10/site-packages/torch/autograd/__init__.py", line 411, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 1024]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
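For context on the error class itself: autograd stores a version counter with every tensor it saves for backward, every in-place op bumps that counter (views share it with their base), and the backward pass refuses to run if a saved tensor's version changed. The sketch below reproduces the same message in plain eager mode, using the `einsum('b i d, b j d -> b i j', ...)` contraction from the traceback written as `torch.bmm`. The shapes and the helper names `trigger_version_counter_error` / `out_of_place_ok` are hypothetical, chosen only for illustration; this does not reproduce the AOTAutograd-specific failure reported here.

```python
import torch


def trigger_version_counter_error():
    """Eager-mode sketch of the error class above (illustrative shapes only)."""
    x = torch.randn(1, 4, 8, requires_grad=True)
    y = torch.randn(1, 5, 8)
    # Same contraction as einsum('b i d, b j d -> b i j', x, y) * -2.
    # BmmBackward0 saves its second operand (a transposed view of y) to compute x.grad.
    xy = torch.bmm(x, y.transpose(1, 2)) * -2
    # In-place update of y bumps the version counter it shares with its views,
    # so the saved operand is now at version 1 instead of 0.
    y.mul_(0.5)
    try:
        xy.sum().backward()
    except RuntimeError as e:
        return str(e)
    return None


def out_of_place_ok():
    """Same computation, but y is rebound out of place, so the saved tensor is untouched."""
    x = torch.randn(1, 4, 8, requires_grad=True)
    y = torch.randn(1, 5, 8)
    xy = torch.bmm(x, y.transpose(1, 2)) * -2
    y = y * 0.5  # new tensor; the operand saved by BmmBackward0 keeps version 0
    xy.sum().backward()
    return x.grad is not None


print(trigger_version_counter_error())
print(out_of_place_ok())
```

In eager code the usual fix is to make the update out of place (or `clone()` before mutating). Here the error only appears under `torch.compile`, so the eager sketch explains the message, not the root cause of this issue.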

Here is the Dynamo output graph that we tried to run through AOTAutograd: https://gist.github.com/ezyang/c7e157dc9f8314794e90e4c314eb3c21

Full repro code: gist.github.com/ezyang/64c24c9fc5529f3afed4ee4266f6adc5
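The traceback shows the failure inside `aot_module_simplified`, while `torch.autograd.grad` traces the joint forward/backward graph, i.e. before any Inductor code generation. A common way to confirm this is to rerun with `backend="aot_eager"` (AOTAutograd without Inductor) and `backend="eager"` (Dynamo only): if AOTAutograd is at fault, `aot_eager` would likely raise the same error while `eager` would not. The harness below is a minimal sketch of that comparison; `cdist_like`, `grad_wrt_x`, and `compare_backends` are hypothetical helpers, and the toy function only mirrors the `cdist` einsum from the traceback, so it is expected to pass on every backend. To use it on this issue, substitute the module and inputs from the repro gist.

```python
import torch
import torch._dynamo


def cdist_like(x, y):
    # Mirrors the contraction in vector_quantize_pytorch's cdist from the traceback.
    return torch.einsum('b i d, b j d -> b i j', x, y) * -2


def grad_wrt_x(fn, x, y):
    # Fresh leaf tensor per run so gradients from different backends do not accumulate.
    x = x.detach().clone().requires_grad_(True)
    fn(x, y).sum().backward()
    return x.grad


def compare_backends(fn, x, y, backends=("eager", "aot_eager")):
    """Compare compiled gradients against plain eager; add "inductor" to widen the bisection."""
    ref = grad_wrt_x(fn, x, y)
    results = {}
    for backend in backends:
        torch._dynamo.reset()  # drop cached graphs so each backend compiles from scratch
        try:
            got = grad_wrt_x(torch.compile(fn, backend=backend), x, y)
            results[backend] = "ok" if torch.allclose(got, ref, atol=1e-5) else "mismatch"
        except Exception as e:  # e.g. BackendCompilerFailed wrapping the RuntimeError above
            results[backend] = f"error: {type(e).__name__}"
    return results


print(compare_backends(cdist_like, torch.randn(1, 4, 8), torch.randn(1, 5, 8)))
```

Reporting which backends return `error` narrows the bug to Dynamo, AOTAutograd, or Inductor without reading the full traceback.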

Versions

run

cc @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang

Labels: empathy-day, module: aotdispatch, module: pt2-dispatcher, oncall: pt2, triaged