Skip to content

Commit

Permalink
Fix some silly Inductor bugs
Browse files Browse the repository at this point in the history
Should probably figure out how to get type checking going, would have
caught these cases.

Discovered in pursuit of #91719
though this is not enough.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
  • Loading branch information
ezyang committed Jan 25, 2023
1 parent b399007 commit 04de74e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
4 changes: 3 additions & 1 deletion torch/_inductor/mkldnn.py
Expand Up @@ -488,7 +488,9 @@ def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list):

def mkldnn_fuse_fx(gm: torch.fx.GraphModule, example_inputs):
is_cpu = all(
example_input.device == torch.device("cpu") for example_input in example_inputs
example_input.device == torch.device("cpu")
for example_input in example_inputs
if isinstance(example_input, torch.Tensor)
)

# make sure the autograd is disabled.
Expand Down
4 changes: 3 additions & 1 deletion torch/_inductor/overrides.py
Expand Up @@ -63,7 +63,9 @@ def replace_fx(gm: torch.fx.GraphModule):

def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
is_cpu = all(
example_input.device == torch.device("cpu") for example_input in example_inputs
example_input.device == torch.device("cpu")
for example_input in example_inputs
if isinstance(example_input, torch.Tensor)
)

fake_mode = fake_mode_from_tensors(example_inputs)
Expand Down
1 change: 1 addition & 0 deletions torch/_inductor/utils.py
Expand Up @@ -103,6 +103,7 @@ def convert_shape_to_symint(
return lst
if all(isinstance(i, sympy.Integer) for i in lst):
return [int(i) for i in lst]
from .virtualized import V
return [V.graph.sizevars.shape_env.create_symintnode(i) for i in lst]


Expand Down

0 comments on commit 04de74e

Please sign in to comment.