Make CPU inductor work with dynamic shapes
These errors were found by looking at wav2vec2

See #91719

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 51b6d644b4e30b9137354a4fb5eece5534d3412e
Pull Request resolved: #93077
ezyang committed Jan 26, 2023
1 parent 7012d98 commit d50087d
Showing 2 changed files with 8 additions and 2 deletions.
6 changes: 5 additions & 1 deletion torch/_inductor/mkldnn.py
@@ -8,6 +8,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from torch import sym_int
+
 from torch._dynamo.utils import fake_mode_from_tensors
 from torch.fx.experimental.optimization import (
     matches_module_pattern,
@@ -351,7 +353,9 @@ def __init__(self, linear: nn.Module, input_size: list):
 
     def _update_module_params(self, linear, input_size):
        self.__dict__ = copy.deepcopy(linear.__dict__)
-        self.batch_size = int(reduce(lambda x, y: x * y, input_size) / input_size[-1])
+        self.batch_size = sym_int(
+            reduce(lambda x, y: x * y, input_size) / input_size[-1]
+        )
         self.packed_weight = torch.nn.Parameter(
             torch.ops.mkl._mkl_reorder_linear_weight(
                 self.weight.to_mkldnn(), self.batch_size
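The batch-size computation above is the shape-dependent piece of the MKL linear packing: int() would specialize (or fail) when input_size contains SymInts, while sym_int keeps the value symbolic. A minimal sketch of the difference, assuming only the public torch.sym_int helper; the batch_size_from_input_size function below is hypothetical and only for illustration, not part of the commit:

from functools import reduce

import torch


def batch_size_from_input_size(input_size):
    # Product of every dimension except the last, e.g. [B, T, C] -> B * T.
    # torch.sym_int keeps the result symbolic when the entries are SymInts,
    # whereas int(...) would force a concrete value during dynamic-shape tracing.
    return torch.sym_int(reduce(lambda x, y: x * y, input_size) / input_size[-1])


# With plain Python ints the behavior is unchanged:
assert batch_size_from_input_size([8, 16, 32]) == 128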
4 changes: 3 additions & 1 deletion torch/_inductor/overrides.py
@@ -4,6 +4,7 @@
 import weakref
 
 import torch
+import torch._dynamo.config as dynamo_config
 import torch.nn as nn
 from torch import _prims
 from torch._dynamo.utils import fake_mode_from_tensors
@@ -87,7 +88,8 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
     gm = remove_identity(gm)
     gm = fuse_conv_bn(gm)
     # do mkldnn fusion(conv(linear)+unary(binary)
-    gm = mkldnn_fuse_fx(gm, example_inputs)
+    if not dynamo_config.dynamic_shapes:
+        gm = mkldnn_fuse_fx(gm, example_inputs)
     return gm
 
 
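The guard added to fuse_fx reflects that the mkldnn fusion pass packs weights against the concrete sizes of example_inputs, which are symbolic when dynamic shapes are enabled. A minimal sketch of the resulting control flow, assuming the torch._dynamo.config.dynamic_shapes flag from this commit's era; maybe_fuse below is a hypothetical stand-in, not the actual fuse_fx:

import torch
import torch._dynamo.config as dynamo_config


def maybe_fuse(gm: torch.fx.GraphModule, example_inputs, mkldnn_fuse_fn):
    # mkldnn_fuse_fn stands in for torch._inductor.overrides.mkldnn_fuse_fx.
    # Weight packing bakes concrete batch sizes into the packed parameters,
    # so the pass is skipped whenever input sizes are symbolic.
    if dynamo_config.dynamic_shapes:
        return gm
    return mkldnn_fuse_fn(gm, example_inputs)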
