This repository was archived by the owner on Aug 1, 2025. It is now read-only.
[inductor] float64 KeyError #1746
Closed
Labels
bug (Something isn't working)
Description
🐛 Describe the bug
The following repro fails: compiling aten.div.Tensor of a float32 CUDA tensor by a 0-dim float64 CPU tensor with compile_fx_inner raises KeyError: 'fp64'.
Error logs
Traceback (most recent call last):
File "/scratch/anijain/work/pytorch/repro.py", line 45, in <module>
compiled = compile_fx_inner(mod, args)
File "/scratch/anijain/work/pytorch/torch/_dynamo/debug_utils.py", line 471, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
File "/scratch/anijain/work/pytorch/torch/_inductor/debug.py", line 178, in inner
return fn(*args, **kwargs)
File "/scratch/anijain/work/env/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/scratch/anijain/work/pytorch/torch/_inductor/compile_fx.py", line 123, in compile_fx_inner
compiled_fn = graph.compile_to_fn()
File "/scratch/anijain/work/pytorch/torch/_inductor/graph.py", line 338, in compile_to_fn
return self.compile_to_module().call
File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 85, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/anijain/work/pytorch/torch/_inductor/graph.py", line 328, in compile_to_module
mod = PyCodeCache.load(code)
File "/scratch/anijain/work/pytorch/torch/_inductor/codecache.py", line 212, in load
exec(code, mod.__dict__, mod.__dict__)
File "/tmp/torchinductor_anijain/h3/ch3eodunx3i3wmnkeoqrc35vsmeyvk464ff3agm4pzftyiopn277.py", line 41, in <module>
async_compile.wait(globals())
File "/scratch/anijain/work/pytorch/torch/_inductor/codecache.py", line 352, in wait
scope[key] = result.result()
File "/scratch/anijain/work/pytorch/torch/_inductor/codecache.py", line 259, in result
self.future.result()
File "/scratch/anijain/work/env/lib/python3.9/concurrent/futures/_base.py", line 446, in result
return self.__get_result()
File "/scratch/anijain/work/env/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
raise self._exception
KeyError: 'fp64'
Minified repro
import torch._inductor.overrides
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx
# torch version: 1.14.0a0+git8a8cd09
# torch cuda version: 11.6
# torch git version: 8a8cd092c8537b226f5c38ed88bc07e181b0946c
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0
# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8
from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg104_1, bmm):
        div = torch.ops.aten.div.Tensor(bmm, arg104_1);  bmm = arg104_1 = None
        return (div,)
args = [((), (), torch.float64, 'cpu'), ((80, 204, 204), (41616, 204, 1), torch.float32, 'cuda')]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
mod = make_fx(Repro().to(device="cuda"))(*args)
mod(*args)
from torch._inductor.compile_fx import compile_fx_inner
from torch._dynamo.debug_utils import same_two_models
compiled = compile_fx_inner(mod, args)
compiled(args)
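
For reference, the division itself succeeds in eager mode (the repro runs mod(*args) before compiling); only the inductor compile path raises the KeyError. A minimal eager-mode sketch, assuming the shapes and dtypes from the args list above (the variable names scalar/bmm/out are illustrative, not from the original dump):

import torch

# 0-dim float64 tensor on CPU and a contiguous float32 CUDA tensor, matching the repro args
scalar = torch.rand((), dtype=torch.float64)
bmm = torch.rand(80, 204, 204, device="cuda", dtype=torch.float32)

# aten.div.Tensor in eager mode does not go through inductor codegen, so no 'fp64' lookup occurs
out = torch.ops.aten.div.Tensor(bmm, scalar)
print(out.dtype, out.shape)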
Metadata
Assignees
Labels
bug (Something isn't working)