[inductor] convert layout of conv weight ahead of time for inference
ghstack-source-id: 8026814f2d984c5454d8fc28b2bab7aa0fc833de
Pull Request resolved: #103642
shunting314 committed Jun 16, 2023
1 parent 49dcf48 commit 1c23dca
Showing 5 changed files with 101 additions and 16 deletions.
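For orientation, here is a minimal sketch (not part of the diff) of the inference path this change targets: with inductor freezing enabled, convolution weights become graph constants, so their layout can be converted to channels_last once at compile time rather than on every call. It assumes the inductor freezing flag (torch._inductor.config.freezing) is available on this branch.

import torch
from torch._inductor import config as inductor_config

inductor_config.freezing = True  # assumption: freezing flag exists on this branch

model = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=False).cuda().eval()
compiled = torch.compile(model)

with torch.no_grad():  # freezing targets the inference (no-grad) path
    out = compiled(torch.rand(2, 3, 5, 5, device="cuda"))
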
49 changes: 49 additions & 0 deletions test/inductor/test_inductor_freezing.py
@@ -10,6 +10,7 @@
import torch

import torch._dynamo
from torch import nn
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
@@ -229,6 +230,54 @@ def foo(mod, inp):
self.assertEqual(eager, compiled)
self.assertTrue(weight_ref() is None)

def test_conv_weight_layout_convert(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(
3, 128, kernel_size=3, padding=1, stride=1, bias=False
)

def forward(self, x):
return self.conv(x)

def get_example_inputs(self):
return (torch.rand(2, 3, 5, 5).cuda(),)

from torch._inductor.compile_fx import compile_fx, compile_fx_inner

nconv = 0

def my_inner_compile(gm, example_inputs, *args, **kwargs):
out = compile_fx_inner(gm, example_inputs, *args, **kwargs)

nonlocal nconv
convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
nconv += len(convs)
for conv in convs:
weight_node = conv.args[1]
weight_const_tensor = getattr(gm, weight_node.target)
self.assertTrue(
weight_const_tensor.is_contiguous(memory_format=torch.channels_last)
)
self.assertTrue(
weight_node.meta["val"].is_contiguous(
memory_format=torch.channels_last
)
)

return out

mod = torch.compile(
Model().cuda(),
backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
)
inp = mod.get_example_inputs()
with torch.no_grad():
mod(*inp)

self.assertTrue(nconv == 1)


if HAS_CPU and not torch.backends.mps.is_available():

8 changes: 0 additions & 8 deletions torch/_functorch/aot_autograd.py
@@ -2783,14 +2783,6 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
torch._guards.TracingContext.get().fw_metadata = fw_metadata


# the compiler needs to use this field to find the original model outputs
# from the AOTAutograd fwd module's outputs. Thus the compiler can make sure
# optimizations like layout optimization do not change those tensors'
# layout.
# TODO once https://github.com/pytorch/pytorch/pull/100652/files#r1212002707 is in
# change to access fw_metadata from the global tracing context.
fw_module.meta["original_output_start_index"] = fw_metadata.num_mutated_inputs

compiled_fw_func = aot_config.fw_compiler(
fw_module, adjusted_flat_args
)
12 changes: 7 additions & 5 deletions torch/_inductor/compile_fx.py
Expand Up @@ -244,6 +244,7 @@ def compile_fx_inner(
is_inference=False,
boxed_forward_device_index=None,
user_visible_outputs=frozenset(),
layout_opt=None,
):
if is_tf32_warning_applicable(gm):
_warn_tf32_disabled()
@@ -314,6 +315,7 @@ def compile_fx_inner(
cpp_wrapper=cpp_wrapper,
aot_mode=aot_mode,
user_visible_outputs=user_visible_outputs,
layout_opt=layout_opt,
)
with V.set_graph_handler(graph):
graph.run(*example_inputs)
@@ -674,7 +676,7 @@ def fw_compiler_freezing(
graph_id,
forward_device,
):
from torch._inductor.freezing import freeze
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

# partition_fn won't be called
joint_graph_passes(aot_autograd_model)
@@ -684,6 +686,9 @@
aot_autograd_model,
fw_metadata=torch._guards.TracingContext.get().fw_metadata,
)
layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model)
if layout_opt:
convert_conv_weights_to_channels_last(aot_autograd_model)

aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
num_fixed = len(preserved_arg_indices) - num_example_inputs
@@ -700,6 +705,7 @@
graph_id=graph_id,
is_inference=True,
boxed_forward_device_index=forward_device,
layout_opt=layout_opt,
)

# Need to drop the args we have constant-ified.
@@ -821,12 +827,8 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
orig_model_outputs_node.args
)
num_orig_model_outputs = len(orig_model_outputs)
original_output_start_index = model.meta.get(
"original_output_start_index", 0
)
else:
num_orig_model_outputs = num_model_outputs
original_output_start_index = 0

assert num_orig_model_outputs <= num_model_outputs

39 changes: 39 additions & 0 deletions torch/_inductor/freezing.py
@@ -4,8 +4,14 @@

import torch
import torch.utils._pytree as pytree
from torch import nn
from torch._dynamo.utils import dynamo_timed
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._mode_utils import no_dispatch
from . import config

aten = torch.ops.aten


def replace_node_with_constant(gm, node, constant):
g = gm.graph
@@ -224,3 +230,36 @@ def discard_traced_gm_params(mod):
e_t.requires_grad_(True)
e_t._is_param = True
setattr(mod, attr_name, e_t)


@dynamo_timed
def convert_conv_weights_to_channels_last(gm):
"""
Convert 4d convolution weight tensors to channels last format.
This method assumes the graph has already been frozen.
"""
convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
for conv in convs:
weight_node = conv.args[1]
# only convert weights that are constant tensors (get_attr nodes)
if weight_node.op == "get_attr":
param_tensor = getattr(gm, weight_node.target)
if len(param_tensor.shape) != 4:
# not a 4d tensor, skip
continue
with no_dispatch():
cl_param_tensor = param_tensor.to(memory_format=torch.channels_last)
if isinstance(param_tensor, nn.Parameter):
cl_param_tensor = nn.Parameter(cl_param_tensor)
if cl_param_tensor is not param_tensor:
setattr(gm, weight_node.target, cl_param_tensor)

# Even though inductor does not use meta['val'] or meta['tensor_meta']
# for get_attr nodes, we still update them for consistency.
weight_node.meta["val"] = weight_node.meta["val"].to(
memory_format=torch.channels_last
)
weight_node.meta["tensor_meta"] = _extract_tensor_metadata(
weight_node.meta["val"]
)
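For reference, the conversion applied to each 4-D weight above is the standard PyTorch channels_last memory-format conversion; a standalone illustration, independent of this pass:

import torch

w = torch.randn(128, 3, 3, 3)                   # conv weight: (out_ch, in_ch, kH, kW)
w_cl = w.to(memory_format=torch.channels_last)  # same shape, NHWC-ordered strides

print(w.stride())     # (27, 9, 3, 1): default contiguous (NCHW) layout
print(w_cl.stride())  # (27, 1, 9, 3): channel stride becomes 1
print(w_cl.is_contiguous(memory_format=torch.channels_last))  # True
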
9 changes: 6 additions & 3 deletions torch/_inductor/graph.py
Expand Up @@ -150,10 +150,13 @@ def __init__(
cpp_wrapper=False,
aot_mode=False,
user_visible_outputs=frozenset(),
layout_opt=None,
):
super().__init__(gm)

self.layout_opt = self.decide_layout_opt()
self.layout_opt = (
layout_opt if layout_opt is not None else self.decide_layout_opt(gm)
)
self.num_channels_last_conv = 0

self.extra_traceback = False # we do our own error wrapping
@@ -195,15 +198,15 @@ def __init__(
self._warned_fallback = {"aten.convolution_backward"}
self.user_visible_outputs = user_visible_outputs

def decide_layout_opt(self) -> bool:
@staticmethod
def decide_layout_opt(gm) -> bool:
"""
Decide if we should enable layout optimization for this graph based on
heuristics.
"""
if not config.layout_optimization:
return False

gm = self.module
conv_nodes = [
n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
]
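decide_layout_opt is made a staticmethod above so that fw_compiler_freezing can call it on the FX graph before a GraphLowering instance exists. A simplified stand-in that mirrors only the checks visible above, not the full heuristic:

import torch
from torch._inductor import config

def decide_layout_opt_sketch(gm: torch.fx.GraphModule) -> bool:
    # Simplified stand-in, not the real heuristic: honor the global switch,
    # then enable layout optimization only if the graph contains convolutions.
    if not config.layout_optimization:
        return False
    conv_nodes = [
        n for n in gm.graph.nodes
        if n.target == torch.ops.aten.convolution.default
    ]
    return len(conv_nodes) > 0
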
