
topk, bmm, max ops - Multiple dispatch failed for 'torch.ops.aten.size'; all __torch_dispatch__ handlers returned NotImplemented: #122772

Closed
rahulvijayaraghavan opened this issue Mar 27, 2024 · 2 comments
Labels
module: dynamic shapes · module: fakeTensor · module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op) · oncall: pt2 · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

rahulvijayaraghavan commented Mar 27, 2024

🐛 Describe the bug

The following error is observed for the out= variants of the topk, bmm, and max ops:
Multiple dispatch failed for 'torch.ops.aten.size'; all __torch_dispatch__ handlers returned NotImplemented:

Error logs

E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1] fake tensor raised TypeError
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1] Traceback (most recent call last):
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 862, in __torch_dispatch__
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]     return self.dispatch(func, types, args, kwargs)
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1207, in dispatch
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]     return self._cached_dispatch_impl(func, types, args, kwargs)
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 940, in _cached_dispatch_impl
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]     output = self._dispatch_impl(func, types, args, kwargs)
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1424, in _dispatch_impl
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]     r = func(*args, **kwargs)
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 600, in __call__
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]     return self_._op(*args, **kwargs)
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 20, in wrapper
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]     return fn(*args, **kwargs)
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 547, in __torch_dispatch__
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]     return func(*args, **kwargs)
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 600, in __call__
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]     return self_._op(*args, **kwargs)
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1] TypeError: Multiple dispatch failed for 'torch.ops.aten.size'; all __torch_dispatch__ handlers returned NotImplemented:
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1] 
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1]   - tensor subclass <class 'torch._subclasses.fake_tensor.FakeTensor'>
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1] 
E0327 08:40:40.192000 137790200328192 torch/_subclasses/fake_tensor.py:864] [0/1] For more information, try re-running with TORCH_LOGS=not_implemented
---------------------------------------------------------------------------
TorchRuntimeError                         Traceback (most recent call last)
<ipython-input-3-db886c9eeb8c> in <cell line: 12>()
     10 
     11 x = torch.arange(1., 8.)
---> 12 opt_model(x, 3, out=(values, indices))
     13 # Multiple dispatch failed for 'torch.ops.aten.size'; all __torch_dispatch__ handlers returned NotImplemented:

34 frames
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    407             prior = set_eval_frame(callback)
    408             try:
--> 409                 return fn(*args, **kwargs)
    410             finally:
    411                 set_eval_frame(prior)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in catch_errors(frame, cache_entry, frame_state)
    937         with compile_lock, _disable_current_modes():
    938             # skip=1: skip this frame
--> 939             return callback(frame, cache_entry, hooks, frame_state, skip=1)
    940 
    941     catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _convert_frame(frame, cache_entry, hooks, frame_state, skip)
    800         counters["frames"]["total"] += 1
    801         try:
--> 802             result = inner_convert(
    803                 frame, cache_entry, hooks, frame_state, skip=skip + 1
    804             )

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _convert_frame_assert(frame, cache_entry, hooks, frame_state, skip)
    398         )
    399 
--> 400         return _compile(
    401             frame.f_code,
    402             frame.f_globals,

/usr/lib/python3.10/contextlib.py in inner(*args, **kwds)
     77         def inner(*args, **kwds):
     78             with self._recreate_cm():
---> 79                 return func(*args, **kwds)
     80         return inner
     81 

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_size, frame, frame_state, compile_id, skip)
    684         fail_user_frame_lineno: Optional[int] = None
    685         try:
--> 686             guarded_code = compile_inner(code, one_graph, hooks, transform)
    687             return guarded_code
    688         except (

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in time_wrapper(*args, **kwargs)
    263             with torch.profiler.record_function(f"{key} (dynamo_timed)"):
    264                 t0 = time.time()
--> 265                 r = func(*args, **kwargs)
    266                 time_spent = time.time() - t0
    267             compilation_time_metrics[key].append(time_spent)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in compile_inner(code, one_graph, hooks, transform)
    539             CompileContext.get().attempt = attempt
    540             try:
--> 541                 out_code = transform_code_object(code, transform)
    542                 break
    543             except exc.RestartAnalysis as e:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py in transform_code_object(code, transformations, safe)
   1034     propagate_line_nums(instructions)
   1035 
-> 1036     transformations(instructions, code_options)
   1037     return clean_and_assemble_instructions(instructions, keys, code_options)[1]
   1038 

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in _fn(*args, **kwargs)
    163         cleanup = setup_compile_debug()
    164         try:
--> 165             return fn(*args, **kwargs)
    166         finally:
    167             cleanup.close()

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py in transform(instructions, code_options)
    501         try:
    502             with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 503                 tracer.run()
    504         except exc.UnspecializeRestartAnalysis:
    505             speculation_log.clear()

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
   2114 
   2115     def run(self):
-> 2116         super().run()
   2117 
   2118     def match_nested_cell(self, name, cell):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in run(self)
    837             try:
    838                 self.output.push_tx(self)
--> 839                 while self.step():
    840                     pass
    841             except BackendCompilerFailed:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in step(self)
    751 
    752         try:
--> 753             self.dispatch_table[inst.opcode](self, inst)
    754             return not self.output.should_exit
    755         except ReturnValueOp:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in wrapper(self, inst)
    479                 return handle_graph_break(self, inst, speculation.reason)
    480             try:
--> 481                 return inner_fn(self, inst)
    482             except Unsupported as excp:
    483                 if self.generic_context_manager_depth > 0:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in CALL_FUNCTION_KW(self, inst)
   1253         kwargs = dict(zip(argnames, kwargs_list))
   1254         assert len(kwargs) == len(argnames)
-> 1255         self.call_function(fn, args, kwargs)
   1256 
   1257     def LOAD_METHOD_SUPER(self, inst):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py in call_function(self, fn, args, kwargs)
    691         if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    692             raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 693         self.push(fn.call_function(self, args, kwargs))
    694 
    695     def inline_user_function_return(self, fn, args, kwargs):

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/torch.py in call_function(self, tx, args, kwargs)
    743                     fn_ = getattr(torch, torch_sym_op)
    744 
--> 745             tensor_variable = wrap_fx_proxy(
    746                 tx=tx,
    747                 proxy=tx.output.create_proxy(

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py in wrap_fx_proxy(tx, proxy, example_value, subclass_type, **options)
   1339     }
   1340     if subclass_type is None:
-> 1341         return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
   1342     else:
   1343         result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py in wrap_fx_proxy_cls(target_cls, tx, proxy, example_value, subclass_type, **options)
   1424         # only allow_non_graph_fake in this instance because we handle the non-fake
   1425         # cases properly below.
-> 1426         example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
   1427 
   1428     # Handle recursive calls here

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in get_fake_value(node, tx, allow_non_graph_fake)
   1744             unimplemented(f"TypeError {node.target}: {cause}")
   1745 
-> 1746         raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
   1747 
   1748     if not allow_non_graph_fake:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in get_fake_value(node, tx, allow_non_graph_fake)
   1676     try:
   1677         with tx.fake_mode, enable_python_dispatcher():
-> 1678             ret_val = wrap_fake_exception(
   1679                 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   1680             )

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in wrap_fake_exception(fn)
   1207 def wrap_fake_exception(fn):
   1208     try:
-> 1209         return fn()
   1210     except UnsupportedFakeTensorException as e:
   1211         from .exc import unimplemented

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in <lambda>()
   1677         with tx.fake_mode, enable_python_dispatcher():
   1678             ret_val = wrap_fake_exception(
-> 1679                 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   1680             )
   1681     except Unsupported:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in run_node(tracer, node, args, kwargs, nnmodule)
   1812             unimplemented(make_error_message(e), from_exc=e)
   1813         except Exception as e:
-> 1814             raise RuntimeError(make_error_message(e)).with_traceback(
   1815                 e.__traceback__
   1816             ) from e

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py in run_node(tracer, node, args, kwargs, nnmodule)
   1794         try:
   1795             if op == "call_function":
-> 1796                 return node.target(*args, **kwargs)
   1797             elif op == "call_method":
   1798                 return getattr(args[0], node.target)(*args[1:], **kwargs)

/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py in wrapper(*args, **kwargs)
     18             simple_call_counter[fn.__qualname__] = 0
     19         simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
---> 20         return fn(*args, **kwargs)
     21     return wrapper

/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py in __torch_dispatch__(self, func, types, args, kwargs)
    860         ), func
    861         try:
--> 862             return self.dispatch(func, types, args, kwargs)
    863         except TypeError:
    864             log.exception("fake tensor raised TypeError")

/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py in dispatch(self, func, types, args, kwargs)
   1205 
   1206         if self.cache_enabled:
-> 1207             return self._cached_dispatch_impl(func, types, args, kwargs)
   1208         else:
   1209             return self._dispatch_impl(func, types, args, kwargs)

/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py in _cached_dispatch_impl(self, func, types, args, kwargs)
    938 
    939         if output is unassigned:
--> 940             output = self._dispatch_impl(func, types, args, kwargs)
    941 
    942         return output

/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py in _dispatch_impl(self, func, types, args, kwargs)
   1422         try:
   1423             with in_kernel_invocation_manager(self):
-> 1424                 r = func(*args, **kwargs)
   1425         except NotImplementedError as not_implemented_error:
   1426             return maybe_run_unsafe_fallback(not_implemented_error)

/usr/local/lib/python3.10/dist-packages/torch/_ops.py in __call__(self_, *args, **kwargs)
    598         # use `self_` to avoid naming collide with aten ops arguments that
    599         # are named "self". This way, all the aten ops can be called by kwargs.
--> 600         return self_._op(*args, **kwargs)
    601 
    602     def __hash__(self):

/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py in wrapper(*args, **kwargs)
     18             simple_call_counter[fn.__qualname__] = 0
     19         simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
---> 20         return fn(*args, **kwargs)
     21     return wrapper

/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py in __torch_dispatch__(cls, func, types, args, kwargs)
    545 
    546         with fake_mode:  # type: ignore[attr-defined]
--> 547             return func(*args, **kwargs)
    548 
    549     @staticmethod

/usr/local/lib/python3.10/dist-packages/torch/_ops.py in __call__(self_, *args, **kwargs)
    598         # use `self_` to avoid naming collide with aten ops arguments that
    599         # are named "self". This way, all the aten ops can be called by kwargs.
--> 600         return self_._op(*args, **kwargs)
    601 
    602     def __hash__(self):

TorchRuntimeError: Failed running call_function <built-in method topk of type object at 0x7d5191f98460>(*(FakeTensor(..., size=(s0,)), 3), **{'out': (FakeTensor(..., size=(3,)), FakeTensor(..., size=(3,), dtype=torch.int64))}):
Multiple dispatch failed for 'torch.ops.aten.size'; all __torch_dispatch__ handlers returned NotImplemented:

  - tensor subclass <class 'torch._subclasses.fake_tensor.FakeTensor'>

For more information, try re-running with TORCH_LOGS=not_implemented

from user code:
   File "<ipython-input-3-db886c9eeb8c>", line 2, in topk_func
    torch.topk(input, k, out=out)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Minified repro

topk

import torch

def topk_func(input, k, out):
    torch.topk(input, k, out=out)

values = torch.empty(3)
indices = torch.empty(3, dtype=torch.long)
opt_model = torch.compile(topk_func)

x = torch.arange(1., 6.)
opt_model(x, 3, out=(values, indices))

x = torch.arange(1., 8.)
opt_model(x, 3, out=(values, indices))

bmm

import torch

def bmm_func(input, mat, out):
    torch.bmm(input, mat, out=out)

opt_model = torch.compile(bmm_func)

inp1 = torch.randn(10, 3, 4)
mat1 = torch.randn(10, 4, 5)
out1 = torch.empty(10, 3, 5)
opt_model(inp1, mat1, out1)

inp2 = torch.randn(12, 4, 5)
mat2 = torch.randn(12, 5, 6)
out2 = torch.empty(12, 4, 6)
opt_model(inp2, mat2, out2)

max

import torch

def max_func(input, out):
    torch.max(input, 0, keepdim=True, out=out)

max = torch.empty(1)
max_indices = torch.empty(1, dtype=torch.long)

opt_model = torch.compile(max_func)

inp1 = torch.randn(4)
opt_model(inp1, out=(max, max_indices))

inp2 = torch.randn(5)
opt_model(inp2, out=(max, max_indices))
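For comparison (my own sketch, not part of the original report): the functional, non-out= forms of the same three ops trace fine under torch.compile, which suggests a possible workaround while out= handling is broken. backend="eager" is used here only to keep the example light; it exercises the same Dynamo/fake-tensor tracing that fails above.

```python
import torch

# Functional variants: results are returned rather than written into
# preallocated out= tensors, so tracing never has to resize anything.
def topk_fn(x, k):
    return torch.topk(x, k)

def bmm_fn(a, b):
    return torch.bmm(a, b)

def max_fn(x):
    return torch.max(x, 0, keepdim=True)

opt_topk = torch.compile(topk_fn, backend="eager")
opt_bmm = torch.compile(bmm_fn, backend="eager")
opt_max = torch.compile(max_fn, backend="eager")

values, indices = opt_topk(torch.arange(1., 8.), 3)
out = opt_bmm(torch.randn(10, 3, 4), torch.randn(10, 4, 5))
m, m_idx = opt_max(torch.randn(5))
```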

Versions

PyTorch version: 2.4.0.dev20240326+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.9
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.58+-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             2
On-line CPU(s) list:                0,1
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family:                         6
Model:                              79
Thread(s) per core:                 2
Core(s) per socket:                 1
Socket(s):                          1
Stepping:                           0
BogoMIPS:                           4399.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          32 KiB (1 instance)
L1i cache:                          32 KiB (1 instance)
L2 cache:                           256 KiB (1 instance)
L3 cache:                           55 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0,1
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Vulnerable; SMT Host state unknown
Vulnerability Meltdown:             Vulnerable
Vulnerability Mmio stale data:      Vulnerable
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:           Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] torch==2.4.0.dev20240326+cpu
[pip3] torchaudio==2.2.1+cu121
[pip3] torchdata==0.7.1
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.17.1
[pip3] torchvision==0.17.1+cu121
[pip3] triton==2.2.0
[conda] Could not collect

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @eellison


ezyang commented Mar 28, 2024

Yeah, there's some sort of problem here. It's also the reason why test_make_fx_symbolic_exhaustive_nn_functional_fractional_max_pool3d_cpu_float32 is xfailed.

@ezyang ezyang added the module: fakeTensor and module: pt2-dispatcher labels Mar 28, 2024
@yf225 yf225 added the triaged label Mar 30, 2024

ezyang commented Apr 23, 2024

What it looks like is happening is that we're attempting to run the underlying meta for topk (e.g.), and somehow we're redispatching to fake tensor mode (!!!)

ezyang added a commit that referenced this issue Apr 23, 2024
Fixes #122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: e52ae2e2d746732f02b498f3efffd47e307ff5b7
Pull Request resolved: #124760
ezyang added a commit that referenced this issue Apr 23, 2024
Fixes #122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: e52ae2e2d746732f02b498f3efffd47e307ff5b7
Pull Request resolved: #124760
ezyang added a commit that referenced this issue Apr 23, 2024
Fixes #122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 7a61c58d8f06150e4a2400b6e1ad6faa9551d83a
Pull Request resolved: #124760
ezyang added a commit that referenced this issue Apr 24, 2024
Fixes #122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: c226554fab2945fd69f24db5d756876f2b72f72c
Pull Request resolved: #124760
ezyang added a commit that referenced this issue Apr 24, 2024
Fixes #122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 32407d6b001ded01eb96a9ca44c260f4f2ce4b71
Pull Request resolved: #124760
OnlyFor pushed a commit to OnlyFor/pytorch that referenced this issue Apr 24, 2024
Fixes pytorch#122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 88d5a17ea80313d8a506616ceb618a97eec7f2a4
Pull Request resolved: pytorch#124760
alat-rights pushed a commit to alat-rights/pytorch that referenced this issue Apr 26, 2024
…ypes from toInt/etc in IValue (pytorch#124760)

Fixes pytorch#122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: pytorch#124760
Approved by: https://github.com/albanD, https://github.com/eellison
carmocca pushed a commit to carmocca/pytorch that referenced this issue Apr 29, 2024
…ypes from toInt/etc in IValue (pytorch#124760)

Fixes pytorch#122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: pytorch#124760
Approved by: https://github.com/albanD, https://github.com/eellison
andoorve pushed a commit to andoorve/pytorch that referenced this issue May 1, 2024
…ypes from toInt/etc in IValue (pytorch#124760)

Fixes pytorch#122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: pytorch#124760
Approved by: https://github.com/albanD, https://github.com/eellison
pytorch-bot bot pushed a commit that referenced this issue May 3, 2024
…ypes from toInt/etc in IValue (#124760)

Fixes #122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #124760
Approved by: https://github.com/albanD, https://github.com/eellison