torch.compile fails when applied to tensor views that have been modified by in-place operators #113793

Open
hjmshi opened this issue Nov 15, 2023 · 1 comment
Labels
module: aotdispatch (umbrella label for AOTAutograd issues), oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@hjmshi
Contributor

hjmshi commented Nov 15, 2023

🐛 Describe the bug

When running torch.compile on a function that takes a view of a tensor and modifies it in-place, we run into the runtime error:

RuntimeError: A view was created in no_grad mode and its base or another view of its base has been modified inplace with grad mode enabled. Given that this use case is ambiguous and error-prone, it is forbidden. You can clarify your code by moving both the view and the inplace either both inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want the inplace to be tracked).

This pattern is needed when applying torch.compile to the step function of our Distributed Shampoo implementation, where we similarly create views of the parameters and update them in-place.

Interestingly, when we ran this code on an earlier version of PyTorch (2.0.1), we did not observe this issue. If we remove the return view_param, we also do not observe the error; a sketch of that variant is below.
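For reference, here is a minimal sketch of that variant with the return removed (same setup as the minified repro below; the function name is just illustrative), which does not hit the error:

import torch
from torch import Tensor

@torch.no_grad()
@torch.compile(backend='aot_eager')
def compute_in_place_add_no_return(param: Tensor) -> None:
    # Same as the minified repro below, except the view is not returned.
    view_param = param[:]
    view_param.add_(1.0)

a = torch.tensor([1.0], requires_grad=True)
for _ in range(2):
    compute_in_place_add_no_return(a)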

This issue is related to the following issues: #99968 (comment) #11390 (comment).

cc: @tsunghsienlee @shintaro-iwasaki @minddrummer @csmiler @mlazos @bdhirsh @yuchenhao

Error logs

backend='aot_eager' raised:
RuntimeError: A view was created in no_grad mode and its base or another view of its base has been modified inplace with grad mode enabled. Given that this use case is ambiguous and error-prone, it is forbidden. You can clarify your code by moving both the view and the inplace either both inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want the inplace to be tracked).
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
---------------------------------------------------------------------------
BackendCompilerFailed                     Traceback (most recent call last)
Cell In[1], line 13
     11 a = torch.tensor([1.0], requires_grad=True)
     12 for i in range(2):
---> 13     (compute_in_place_add(a))
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/eval_frame.py:408, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    406 dynamic_ctx.__enter__()
    407 try:
--> 408     return fn(*args, **kwargs)
    409 finally:
    410     set_eval_frame(prior)
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/eval_frame.py:569, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_entry, frame_state)
    566             return hijacked_callback(frame, cache_entry, hooks, frame_state)
    568 with compile_lock, _disable_current_modes():
--> 569     return callback(frame, cache_entry, hooks, frame_state)
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/convert_frame.py:671, in convert_frame.<locals>._convert_frame(frame, cache_entry, hooks, frame_state)
    669 counters["frames"]["total"] += 1
    670 try:
--> 671     result = inner_convert(frame, cache_entry, hooks, frame_state)
    672     counters["frames"]["ok"] += 1
    673     return result
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/convert_frame.py:377, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_entry, hooks, frame_state)
    363 compile_id = CompileId(frame_id, frame_compile_id)
    365 signpost_event(
    366     "dynamo",
    367     "_convert_frame_assert._compile",
   (...)
    374     },
    375 )
--> 377 return _compile(
    378     frame.f_code,
    379     frame.f_globals,
    380     frame.f_locals,
    381     frame.f_builtins,
    382     compiler_fn,
    383     one_graph,
    384     export,
    385     export_constraints,
    386     hooks,
    387     cache_size,
    388     frame,
    389     frame_state=frame_state,
    390     compile_id=compile_id,
    391 )
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/convert_frame.py:595, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_size, frame, frame_state, compile_id)
    593 with compile_context(CompileContext(compile_id)):
    594     try:
--> 595         guarded_code = compile_inner(code, one_graph, hooks, transform)
    596         return guarded_code
    597     except (
    598         Unsupported,
    599         TorchRuntimeError,
   (...)
    606         BisectValidationException,
    607     ) as e:
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/utils.py:243, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    241 with torch.profiler.record_function(f"{key} (dynamo_timed)"):
    242     t0 = time.time()
--> 243     r = func(*args, **kwargs)
    244     time_spent = time.time() - t0
    245 compilation_time_metrics[key].append(time_spent)
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/convert_frame.py:512, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
    510 CompileContext.get().attempt = attempt
    511 try:
--> 512     out_code = transform_code_object(code, transform)
    513     orig_code_map[out_code] = code
    514     break
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/bytecode_transformation.py:1033, in transform_code_object(code, transformations, safe)
   1030 instructions = cleaned_instructions(code, safe)
   1031 propagate_line_nums(instructions)
-> 1033 transformations(instructions, code_options)
   1034 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/convert_frame.py:150, in preserve_global_state.<locals>._fn(*args, **kwargs)
    148 cleanup = setup_compile_debug()
    149 try:
--> 150     return fn(*args, **kwargs)
    151 finally:
    152     cleanup.close()
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/convert_frame.py:477, in _compile.<locals>.transform(instructions, code_options)
    475 try:
    476     with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 477         tracer.run()
    478 except exc.UnspecializeRestartAnalysis:
    479     speculation_log.clear()
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/symbolic_convert.py:2120, in InstructionTranslator.run(self)
   2119 def run(self):
-> 2120     super().run()
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/symbolic_convert.py:815, in InstructionTranslatorBase.run(self)
    810 try:
    811     self.output.push_tx(self)
    812     while (
    813         self.instruction_pointer is not None
    814         and not self.output.should_exit
--> 815         and self.step()
    816     ):
    817         pass
    818 except BackendCompilerFailed:
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/symbolic_convert.py:778, in InstructionTranslatorBase.step(self)
    774         unimplemented(f"missing: {inst.opname}")
    775     TracingContext.set_current_loc(
    776         self.f_code.co_filename, self.lineno, self.f_code.co_name
    777     )
--> 778     getattr(self, inst.opname)(inst)
    780     return inst.opname != "RETURN_VALUE"
    781 except Unsupported:
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/symbolic_convert.py:2235, in InstructionTranslator.RETURN_VALUE(self, inst)
   2230 _step_logger()(
   2231     logging.INFO,
   2232     f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",
   2233 )
   2234 log.debug("RETURN_VALUE triggered compile")
-> 2235 self.output.compile_subgraph(
   2236     self,
   2237     reason=GraphCompileReason(
   2238         "return_value", [self.frame_summary()], graph_break=False
   2239     ),
   2240 )
   2241 self.output.add_output_instructions([create_instruction("RETURN_VALUE")])
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/output_graph.py:879, in OutputGraph.compile_subgraph(self, tx, partial_convert, reason)
    876     append_prefix_insts()
    877     # optimization to generate better code in a common case
    878     self.add_output_instructions(
--> 879         self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
    880         + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
    881     )
    882 else:
    883     graph_output_var = self.new_var("graph_out")
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/runtime/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/output_graph.py:1024, in OutputGraph.compile_and_call_fx_graph(self, tx, rv, root)
   1022 self.call_cleanup_hooks()
   1023 with self.restore_global_state():
-> 1024     compiled_fn = self.call_user_compiler(gm)
   1025 compiled_fn = disable(compiled_fn)
   1027 counters["stats"]["unique_graphs"] += 1
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/utils.py:243, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    241 with torch.profiler.record_function(f"{key} (dynamo_timed)"):
    242     t0 = time.time()
--> 243     r = func(*args, **kwargs)
    244     time_spent = time.time() - t0
    245 compilation_time_metrics[key].append(time_spent)
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/output_graph.py:1096, in OutputGraph.call_user_compiler(self, gm)
   1094     raise e
   1095 except Exception as e:
-> 1096     raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
   1097         e.__traceback__
   1098     ) from None
   1100 signpost_event(
   1101     "dynamo",
   1102     "OutputGraph.call_user_compiler",
   (...)
   1108     },
   1109 )
   1111 return compiled_fn
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/output_graph.py:1077, in OutputGraph.call_user_compiler(self, gm)
   1075 if config.verify_correctness:
   1076     compiler_fn = WrapperBackend(compiler_fn)
-> 1077 compiled_fn = compiler_fn(gm, self.example_inputs())
   1078 _step_logger()(logging.INFO, f"done compiler function {name}")
   1079 assert callable(compiled_fn), "compiler_fn did not return callable"
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/repro/after_dynamo.py:117, in wrap_backend_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
    115             raise
    116 else:
--> 117     compiled_gm = compiler_fn(gm, example_inputs)
    119 return compiled_gm
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/__init__.py:1655, in _TorchCompileWrapper.__call__(self, model_, inputs_)
   1654 def __call__(self, model_, inputs_):
-> 1655     return self.compiler_fn(model_, inputs_, **self.kwargs)
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/backends/common.py:55, in aot_autograd.<locals>.compiler_fn(gm, example_inputs)
     52 try:
     53     # NB: NOT cloned!
     54     with enable_aot_logging(), patch_config:
---> 55         cg = aot_module_simplified(gm, example_inputs, **kwargs)
     56         counters["aot_autograd"]["ok"] += 1
     57         return disable(cg)
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_functorch/aot_autograd.py:4939, in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, keep_inference_input_mutations, inference_compiler)
   4923 aot_config = AOTConfig(
   4924     fw_compiler=fw_compiler,
   4925     bw_compiler=bw_compiler,
   (...)
   4935     no_tangents=False,
   4936 )
   4938 with compiled_autograd.disable():
-> 4939     compiled_fn = create_aot_dispatcher_function(
   4940         functional_call,
   4941         full_args,
   4942         aot_config,
   4943     )
   4945 # TODO: There is something deeply wrong here; compiled_fn running with
   4946 # the boxed calling convention, but aot_module_simplified somehow
   4947 # historically returned a function that was not the boxed calling
   4948 # convention.  This should get fixed...
   4949 def forward(*runtime_args):
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_dynamo/utils.py:243, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    241 with torch.profiler.record_function(f"{key} (dynamo_timed)"):
    242     t0 = time.time()
--> 243     r = func(*args, **kwargs)
    244     time_spent = time.time() - t0
    245 compilation_time_metrics[key].append(time_spent)
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_functorch/aot_autograd.py:4386, in create_aot_dispatcher_function(flat_fn, flat_args, aot_config)
   4382 with enable_python_dispatcher():
   4383     # Patch set_rng_state as set_rng_state with fake tensors is
   4384     # nonsensical. This does not affect the collection of metadata.
   4385     with patch("torch.cuda.set_rng_state", lambda *args: None):
-> 4386         fw_metadata = run_functionalized_fw_and_collect_metadata(
   4387             flat_fn,
   4388             keep_input_mutations=aot_config.keep_inference_input_mutations,
   4389             is_train=needs_autograd,
   4390         )(*fake_flat_args)
   4392         req_subclass_dispatch = requires_subclass_dispatch(fake_flat_args, fw_metadata)
   4394         if needs_autograd and not any(x.requires_grad for x in fw_metadata.output_info):
   4395             # We realized that none of the outputs require grad,
   4396             # so we actually have an inference graph.
File /mnt/xarfuse/uid-27458/0b399ed9-seed-nspid4026533690_cgpid4392124-ns-4026533687/torch/_functorch/aot_autograd.py:1285, in run_functionalized_fw_and_collect_metadata.<locals>.inner(*flat_args)
   1282 is_result_of_custom_autograd_fn = False
   1283 if isinstance(o, torch.Tensor):
   1284     # Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) autograd fns
-> 1285     if type(o.grad_fn).__name__ == "CppFunction":
   1286         is_result_of_custom_autograd_fn = True
   1287     if isinstance(o.grad_fn, torch.autograd.function.BackwardCFunction):
BackendCompilerFailed: backend='aot_eager' raised:
RuntimeError: A view was created in no_grad mode and its base or another view of its base has been modified inplace with grad mode enabled. Given that this use case is ambiguous and error-prone, it is forbidden. You can clarify your code by moving both the view and the inplace either both inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want the inplace to be tracked).
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Minified repro

import torch
from torch import Tensor

@torch.no_grad()
@torch.compile(backend='aot_eager')
def compute_in_place_add(param: Tensor):
    view_param = param[:]
    view_param.add_(1.0)
    return view_param
a = torch.tensor([1.0], requires_grad=True)
for _ in range(2):
    compute_in_place_add(a)

Versions

Collecting environment information...
PyTorch version: 2.2.0.dev20231115
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.0 (arm64)
GCC version: Could not collect
Clang version: 14.0.3 (clang-1403.0.22.14.1)
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.13 (main, Sep 11 2023, 08:16:02) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.0-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] numpy==1.26.0
[pip3] torch==2.2.0.dev20231115
[pip3] torchaudio==2.2.0.dev20231115
[pip3] torchvision==0.17.0.dev20231115
[conda] numpy                     1.26.0          py310h3b2db8e_0
[conda] numpy-base                1.26.0          py310ha9811e2_0
[conda] pytorch                   2.2.0.dev20231115        py3.10_0    pytorch-nightly
[conda] torchaudio                2.2.0.dev20231115       py310_cpu    pytorch-nightly
[conda] torchvision               0.17.0.dev20231115       py310_cpu    pytorch-nightly

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @wconstab

@bdhirsh
Contributor

bdhirsh commented Nov 15, 2023

We technically error out in PT2 here while eager mode doesn't raise an error, but this eager code is pretty broken too. If I print the .grad_fn of the output in eager mode, I get the same error:

import torch
from torch import Tensor

@torch.no_grad()
def compute_in_place_add(param: Tensor):
    view_param = param[:]
    view_param.add_(1.0)
    return view_param

a = torch.tensor([1.0], requires_grad=True)
for _ in range(2):
    out = compute_in_place_add(a)
    # raises: RuntimeError: A view was created in no_grad mode and its base or another view of its base has been modified inplace with grad mode enabled. Given that this use case is ambiguous and error-prone, it is forbidden. You can clarify your code by moving both the view and the inplace either both inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want the inplace to be tracked).
    print(out.grad_fn)

We might be able to avoid this error in PT2 if we change all places where we access .grad_fn in the compiler to be some "safe" version that promises not to error if the tensor's grad_fn is invalid.
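For illustration only, a rough sketch of what such a "safe" accessor might look like (this helper is hypothetical, not an existing PyTorch API):

import torch

def safe_grad_fn(t: torch.Tensor):
    # Hypothetical helper: return t.grad_fn, but swallow the RuntimeError
    # raised when the grad_fn is in the invalid no_grad-view state above.
    try:
        return t.grad_fn
    except RuntimeError:
        return None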

But I think you can change your code to avoid the error (both in eager and in PT2) by using .detach():

import torch
from torch import Tensor

@torch.no_grad()
def compute_in_place_add(param: Tensor):
    view_param = param[:]
    view_param.add_(1.0)
    return view_param

a = torch.tensor([1.0], requires_grad=True)
for _ in range(2):
    compute_in_place_add(a.detach())
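
For completeness, a sketch of the same workaround applied to the compiled repro from above (assuming the aot_eager backend used in the original report):

import torch
from torch import Tensor

@torch.no_grad()
@torch.compile(backend='aot_eager')
def compute_in_place_add(param: Tensor):
    view_param = param[:]
    view_param.add_(1.0)
    return view_param

a = torch.tensor([1.0], requires_grad=True)
for _ in range(2):
    # a.detach() shares storage with a, so the in-place add_ still updates
    # a's values, while the view and the mutation stay outside autograd.
    compute_in_place_add(a.detach())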

@yanboliang yanboliang added the triaged label Nov 17, 2023
@yf225 yf225 added the module: aotdispatch label Mar 29, 2024