torch.histc does not work with torch.compile(dynamic=True) #124512

Closed
nopperl opened this issue Apr 19, 2024 · 4 comments
Labels
good first issue · module: dynamic shapes · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@nopperl (Contributor) commented Apr 19, 2024

🐛 Describe the bug

torch.histc does not work with torch.compile(dynamic=True)

Using torch.histc with torch.compile(dynamic=True) raises a NotImplementedError. See the minified repro. The code works fine without dynamic=True.

Following the link in the error message, it seems like the torch.ops.aten.histc.default op does not support dynamic shapes. According to opcheck:

import torch
from torch.testing._internal.optests import opcheck

inputs = torch.rand(3, device="cuda")
opcheck(torch.ops.aten.histc.default, args=(inputs,), kwargs={"bins": 4, "min": 0, "max": 1})

Error message:

...
OpCheckError: opcheck(op, ...): test_aot_dispatch_dynamic failed with aten.histc.default
...

(plus the same stack trace as before)

Note: this is also the case for torch.ops.aten.histc.out.

Error logs

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
File /scratch/pytorch/torch/_subclasses/fake_tensor.py:1432, in FakeTensorMode._dispatch_impl(self, func, types, args, kwargs)
   1431     with in_kernel_invocation_manager(self):
-> 1432         r = func(*args, **kwargs)
   1433 except NotImplementedError as not_implemented_error:

File /scratch/pytorch/torch/_ops.py:629, in OpOverload.__call__(self_, *args, **kwargs)
    626 def __call__(self_, *args, **kwargs):  # noqa: B902
    627     # use `self_` to avoid naming collide with aten ops arguments that
    628     # are named "self". This way, all the aten ops can be called by kwargs.
--> 629     return self_._op(*args, **kwargs)

NotImplementedError: aten::histc: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl.Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit

During handling of the above exception, another exception occurred:

UnsupportedOperatorException              Traceback (most recent call last)
File /scratch/pytorch/torch/_dynamo/utils.py:1865, in run_node(tracer, node, args, kwargs, nnmodule)
   1864 if op == "call_function":
-> 1865     return node.target(*args, **kwargs)
   1866 elif op == "call_method":

File /scratch/pytorch/torch/utils/_stats.py:20, in count.<locals>.wrapper(*args, **kwargs)
     19 simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
---> 20 return fn(*args, **kwargs)

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:870, in FakeTensorMode.__torch_dispatch__(self, func, types, args, kwargs)
    869 try:
--> 870     return self.dispatch(func, types, args, kwargs)
    871 except TypeError:

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:1215, in FakeTensorMode.dispatch(self, func, types, args, kwargs)
   1214 if self.cache_enabled:
-> 1215     return self._cached_dispatch_impl(func, types, args, kwargs)
   1216 else:

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:948, in FakeTensorMode._cached_dispatch_impl(self, func, types, args, kwargs)
    947 if output is unassigned:
--> 948     output = self._dispatch_impl(func, types, args, kwargs)
    950 return output

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:1434, in FakeTensorMode._dispatch_impl(self, func, types, args, kwargs)
   1433 except NotImplementedError as not_implemented_error:
-> 1434     return maybe_run_unsafe_fallback(not_implemented_error)
   1436 return self.wrap_meta_outputs_with_default_device_logic(
   1437     r, func, flat_args, device=kwargs.get("device")
   1438 )

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:1417, in FakeTensorMode._dispatch_impl.<locals>.maybe_run_unsafe_fallback(error)
   1416 if has_symbolic_sizes or not self.can_run_unsafe_fallback(func):
-> 1417     raise UnsupportedOperatorException(func)
   1418 if error is None:

UnsupportedOperatorException: aten.histc.default

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
File /scratch/pytorch/torch/_dynamo/utils.py:1747, in get_fake_value(node, tx, allow_non_graph_fake)
   1746     with tx.fake_mode, enable_python_dispatcher():
-> 1747         ret_val = wrap_fake_exception(
   1748             lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   1749         )
   1750 except Unsupported:

File /scratch/pytorch/torch/_dynamo/utils.py:1262, in wrap_fake_exception(fn)
   1261 try:
-> 1262     return fn()
   1263 except UnsupportedFakeTensorException as e:

File /scratch/pytorch/torch/_dynamo/utils.py:1748, in get_fake_value.<locals>.<lambda>()
   1746     with tx.fake_mode, enable_python_dispatcher():
   1747         ret_val = wrap_fake_exception(
-> 1748             lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   1749         )
   1750 except Unsupported:

File /scratch/pytorch/torch/_dynamo/utils.py:1883, in run_node(tracer, node, args, kwargs, nnmodule)
   1882     except Exception as e:
-> 1883         raise RuntimeError(make_error_message(e)).with_traceback(
   1884             e.__traceback__
   1885         ) from e
   1887 raise AssertionError(op)

File /scratch/pytorch/torch/_dynamo/utils.py:1865, in run_node(tracer, node, args, kwargs, nnmodule)
   1864 if op == "call_function":
-> 1865     return node.target(*args, **kwargs)
   1866 elif op == "call_method":

File /scratch/pytorch/torch/utils/_stats.py:20, in count.<locals>.wrapper(*args, **kwargs)
     19 simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
---> 20 return fn(*args, **kwargs)

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:870, in FakeTensorMode.__torch_dispatch__(self, func, types, args, kwargs)
    869 try:
--> 870     return self.dispatch(func, types, args, kwargs)
    871 except TypeError:

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:1215, in FakeTensorMode.dispatch(self, func, types, args, kwargs)
   1214 if self.cache_enabled:
-> 1215     return self._cached_dispatch_impl(func, types, args, kwargs)
   1216 else:

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:948, in FakeTensorMode._cached_dispatch_impl(self, func, types, args, kwargs)
    947 if output is unassigned:
--> 948     output = self._dispatch_impl(func, types, args, kwargs)
    950 return output

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:1434, in FakeTensorMode._dispatch_impl(self, func, types, args, kwargs)
   1433 except NotImplementedError as not_implemented_error:
-> 1434     return maybe_run_unsafe_fallback(not_implemented_error)
   1436 return self.wrap_meta_outputs_with_default_device_logic(
   1437     r, func, flat_args, device=kwargs.get("device")
   1438 )

File /scratch/pytorch/torch/_subclasses/fake_tensor.py:1417, in FakeTensorMode._dispatch_impl.<locals>.maybe_run_unsafe_fallback(error)
   1416 if has_symbolic_sizes or not self.can_run_unsafe_fallback(func):
-> 1417     raise UnsupportedOperatorException(func)
   1418 if error is None:

RuntimeError: Failed running call_function <built-in method histc of type object at 0x7f7546500a20>(*(FakeTensor(..., device='cuda:0', size=(s0,)),), **{'bins': 4, 'min': 0, 'max': 1}):
aten.histc.default

During handling of the above exception, another exception occurred:

Unsupported                               Traceback (most recent call last)
      3 inputs = torch.rand(3, device="cuda")
      4 histc_opt = torch.compile(torch.histc, dynamic=True, fullgraph=True)
----> 5 histc_opt(inputs, bins=4, min=0, max=1)

File /scratch/pytorch/torch/_dynamo/eval_frame.py:403, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    401 prior = set_eval_frame(callback)
    402 try:
--> 403     return fn(*args, **kwargs)
    404 finally:
    405     set_eval_frame(prior)

File /scratch/pytorch/torch/_dynamo/convert_frame.py:977, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_entry, frame_state)
    973             return hijacked_callback(frame, cache_entry, hooks, frame_state)
    975 with compile_lock, _disable_current_modes():
    976     # skip=1: skip this frame
--> 977     return callback(frame, cache_entry, hooks, frame_state, skip=1)

File /scratch/pytorch/torch/_dynamo/convert_frame.py:411, in convert_frame_assert.<locals>._convert_frame_assert(frame, cache_entry, hooks, frame_state, skip)
    397 compile_id = CompileId(frame_id, frame_compile_id)
    399 signpost_event(
    400     "dynamo",
    401     "_convert_frame_assert._compile",
   (...)
    408     },
    409 )
--> 411 return _compile(
    412     frame.f_code,
    413     frame.f_globals,
    414     frame.f_locals,
    415     frame.f_builtins,
    416     compiler_fn,
    417     one_graph,
    418     export,
    419     export_constraints,
    420     hooks,
    421     cache_size,
    422     frame,
    423     frame_state=frame_state,
    424     compile_id=compile_id,
    425     skip=skip + 1,
    426 )

File /scratch/pytorch/torch/_utils_internal.py:70, in compiletime_sl_profile_meta.<locals>.compiletime_sl_profile_inner.<locals>.wrapper_function(*args, **kwargs)
     68 @functools.wraps(function)
     69 def wrapper_function(*args, **kwargs):
---> 70     return function(*args, **kwargs)

File /scratch/.conda/envs/pytorch/lib/python3.11/contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 @wraps(func)
     79 def inner(*args, **kwds):
     80     with self._recreate_cm():
---> 81         return func(*args, **kwds)

File /scratch/pytorch/torch/_dynamo/convert_frame.py:700, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_size, frame, frame_state, compile_id, skip)
    698 fail_user_frame_lineno: Optional[int] = None
    699 try:
--> 700     guarded_code = compile_inner(code, one_graph, hooks, transform)
    701     return guarded_code
    702 except (
    703     Unsupported,
    704     TorchRuntimeError,
   (...)
    711     BisectValidationException,
    712 ) as e:

File /scratch/pytorch/torch/_dynamo/utils.py:267, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
    265 with torch.profiler.record_function(f"{key} (dynamo_timed)"):
    266     t0 = time.time()
--> 267     r = func(*args, **kwargs)
    268     time_spent = time.time() - t0
    269 compilation_time_metrics[key].append(time_spent)

File /scratch/pytorch/torch/_dynamo/convert_frame.py:568, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
    566 CompileContext.get().attempt = attempt
    567 try:
--> 568     out_code = transform_code_object(code, transform)
    569     break
    570 except exc.RestartAnalysis as e:

File /scratch/pytorch/torch/_dynamo/bytecode_transformation.py:1116, in transform_code_object(code, transformations, safe)
   1113 instructions = cleaned_instructions(code, safe)
   1114 propagate_line_nums(instructions)
-> 1116 transformations(instructions, code_options)
   1117 return clean_and_assemble_instructions(instructions, keys, code_options)[1]

File /scratch/pytorch/torch/_dynamo/convert_frame.py:173, in preserve_global_state.<locals>._fn(*args, **kwargs)
    171 cleanup = setup_compile_debug()
    172 try:
--> 173     return fn(*args, **kwargs)
    174 finally:
    175     cleanup.close()

File /scratch/pytorch/torch/_dynamo/convert_frame.py:515, in _compile.<locals>.transform(instructions, code_options)
    513 try:
    514     with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 515         tracer.run()
    516 except exc.UnspecializeRestartAnalysis:
    517     speculation_log.clear()

File /scratch/pytorch/torch/_dynamo/symbolic_convert.py:2237, in InstructionTranslator.run(self)
   2236 def run(self):
-> 2237     super().run()

File /scratch/pytorch/torch/_dynamo/symbolic_convert.py:875, in InstructionTranslatorBase.run(self)
    873 try:
    874     self.output.push_tx(self)
--> 875     while self.step():
    876         pass
    877 except BackendCompilerFailed:

File /scratch/pytorch/torch/_dynamo/symbolic_convert.py:790, in InstructionTranslatorBase.step(self)
    787 self.update_block_stack(inst)
    789 try:
--> 790     self.dispatch_table[inst.opcode](self, inst)
    791     return not self.output.should_exit
    792 except ReturnValueOp:

File /scratch/pytorch/torch/_dynamo/symbolic_convert.py:492, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    490     return handle_graph_break(self, inst, speculation.reason)
    491 try:
--> 492     return inner_fn(self, inst)
    493 except Unsupported as excp:
    494     if self.generic_context_manager_depth > 0:
    495         # We don't support graph break under GenericContextWrappingVariable,
    496         # If there is, we roll back to the checkpoint and fall back.

File /scratch/pytorch/torch/_dynamo/symbolic_convert.py:1301, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
   1299 # Map to a dictionary of str -> VariableTracker
   1300 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1301 self.call_function(fn, argsvars.items, kwargsvars)

File /scratch/pytorch/torch/_dynamo/symbolic_convert.py:730, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    728 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    729     raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 730 self.push(fn.call_function(self, args, kwargs))

File /scratch/pytorch/torch/_dynamo/variables/lazy.py:131, in _create_realize_and_forward.<locals>.realize_and_forward(self, *args, **kwargs)
    129 @functools.wraps(getattr(VariableTracker, name))
    130 def realize_and_forward(self, *args, **kwargs):
--> 131     return getattr(self.realize(), name)(*args, **kwargs)

File /scratch/pytorch/torch/_dynamo/variables/torch.py:754, in TorchInGraphFunctionVariable.call_function(self, tx, args, kwargs)
    749                 if getattr(self.value, "__module__", None) == "math" and hasattr(
    750                     torch, torch_sym_op
    751                 ):
    752                     fn_ = getattr(torch, torch_sym_op)
--> 754             tensor_variable = wrap_fx_proxy(
    755                 tx=tx,
    756                 proxy=tx.output.create_proxy(
    757                     "call_function",
    758                     fn_,
    759                     *proxy_args_kwargs(args, kwargs),
    760                 ),
    761             )
    763             if (
    764                 isinstance(tensor_variable, TensorVariable)
    765                 and "requires_grad" in kwargs
    766                 and kwargs["requires_grad"].as_python_constant()
    767             ):
    768                 unimplemented(
    769                     """factory functions that return tensors that require grad are not supported.
    770 Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
    771                 )

File /scratch/pytorch/torch/_dynamo/variables/builder.py:1435, in wrap_fx_proxy(tx, proxy, example_value, subclass_type, **options)
   1427 kwargs = {
   1428     "tx": tx,
   1429     "proxy": proxy,
   (...)
   1432     **options,
   1433 }
   1434 if subclass_type is None:
-> 1435     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
   1436 else:
   1437     result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)

File /scratch/pytorch/torch/_dynamo/variables/builder.py:1520, in wrap_fx_proxy_cls(target_cls, tx, proxy, example_value, subclass_type, **options)
   1516 # with preserve_rng_state():
   1517 if example_value is None:
   1518     # only allow_non_graph_fake in this instance because we handle the non-fake
   1519     # cases properly below.
-> 1520     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
   1522 # Handle recursive calls here
   1523 elif maybe_get_fake_mode(example_value) is tx.fake_mode:

File /scratch/pytorch/torch/_dynamo/utils.py:1794, in get_fake_value(node, tx, allow_non_graph_fake)
   1788             module, ctx = maybe_pystub
   1789             import_suggestion = (
   1790                 f"It's possible that the support was implemented in "
   1791                 f"module `{module}` and you may need to `import {module}`"
   1792                 f"({ctx}), otherwise "
   1793             )
-> 1794     unimplemented(
   1795         f"unsupported operator: {cause.func} ({import_suggestion}see "
   1796         "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
   1797         " for how to fix)"
   1798     )
   1799 elif isinstance(
   1800     cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
   1801 ):
   1802     raise UserError(  # noqa: TRY200
   1803         UserErrorType.CONSTRAINT_VIOLATION,
   1804         "Tried to use data-dependent value in the subsequent computation. "
   (...)
   1808         case_name="constrain_as_size_example",
   1809     )

File /scratch/pytorch/torch/_dynamo/exc.py:212, in unimplemented(msg, from_exc)
    210 if from_exc is not _NOTHING:
    211     raise Unsupported(msg) from from_exc
--> 212 raise Unsupported(msg)

Unsupported: unsupported operator: aten.histc.default (see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix)

from user code:
   File "/scratch/pytorch/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Minified repro

import torch

inputs = torch.rand(3, device="cuda")
histc_opt = torch.compile(torch.histc, dynamic=True, fullgraph=True)
histc_opt(inputs, bins=4, min=0, max=1)

Versions

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.11.0
[pip3] torch==2.4.0a0+git87f651c
[pip3] triton==3.0.0

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

@ezyang (Contributor) commented Apr 19, 2024

Adding a meta following https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit shouldn't be too bad.
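
For reference, a minimal sketch of what such a meta (fake) kernel could look like, following the register_meta pattern used in torch/_meta_registrations.py. This is only an illustration under those assumptions, not necessarily what the eventual PR does:

import torch
from torch._meta_registrations import register_meta

aten = torch.ops.aten

# Hypothetical meta kernel for aten.histc; the actual fix may differ in
# its dtype/value checks. histc always returns a 1-D tensor with `bins`
# elements, so the output shape does not depend on the (possibly
# symbolic) input shape.
@register_meta(aten.histc)
def meta_histc(input, bins=100, min=0, max=0):
    torch._check(bins > 0, lambda: f"bins must be > 0, but got {bins}")
    return input.new_empty((bins,))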

@nopperl (Contributor, Author) commented Apr 20, 2024

@ezyang thanks for the pointer. I've added a meta function (#124548) and the above code works now.
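
As a quick sanity check, on a build that includes that meta registration (an assumption; use your local branch with #124548 applied), both the repro and the opcheck call from the description should pass:

import torch
from torch.testing._internal.optests import opcheck

inputs = torch.rand(3, device="cuda")

# The original repro now traces with dynamic shapes instead of raising
# an UnsupportedOperatorException.
histc_opt = torch.compile(torch.histc, dynamic=True, fullgraph=True)
print(histc_opt(inputs, bins=4, min=0, max=1))

# opcheck's test_aot_dispatch_dynamic test should also pass now.
opcheck(torch.ops.aten.histc.default, args=(inputs,), kwargs={"bins": 4, "min": 0, "max": 1})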

@vunnamkowsik commented

Is this issue still open? @nopperl

@nopperl (Contributor, Author) commented Apr 22, 2024

@vunnamkowsik it will be fixed once #124548 is merged.

@masnesral masnesral added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Apr 22, 2024
carmocca pushed a commit to carmocca/pytorch that referenced this issue Apr 29, 2024
Registers a meta function for the `aten.histc.default` and `aten.histc.out` ops to support `torch.compile(dynamic=True)`. Fixes pytorch#124512.

Pull Request resolved: pytorch#124548
Approved by: https://github.com/lezcano, https://github.com/peterbell10
andoorve pushed a commit to andoorve/pytorch that referenced this issue May 1, 2024
Registers a meta function for the `aten.histc.default` and `aten.histc.out` ops to support `torch.compile(dynamic=True)`. Fixes pytorch#124512.

Pull Request resolved: pytorch#124548
Approved by: https://github.com/lezcano, https://github.com/peterbell10
pytorch-bot bot pushed a commit that referenced this issue May 3, 2024
Registers a meta function for the `aten.histc.default` and `aten.histc.out` ops to support `torch.compile(dynamic=True)`. Fixes #124512.

Pull Request resolved: #124548
Approved by: https://github.com/lezcano, https://github.com/peterbell10