[inductor] convert layout of conv weight ahead of time for inference #103642

Status: Closed — wants to merge 18 commits

Changes from 7 commits

Commits (18)
3b03e06
[wip][inductor] convert layout of conv weight ahead of time
shunting314 Jun 15, 2023
ca24229
Update on "[wip][inductor] convert layout of conv weight ahead of tim…
shunting314 Jun 15, 2023
379fabe
Update on "[wip][inductor] convert layout of conv weight ahead of tim…
shunting314 Jun 15, 2023
53e9ed9
Update on "[wip][inductor] convert layout of conv weight ahead of tim…
shunting314 Jun 16, 2023
27b08b4
Update on "[wip][inductor] convert layout of conv weight ahead of tim…
shunting314 Jun 16, 2023
7c200c3
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 16, 2023
0122768
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 16, 2023
230405e
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 22, 2023
dd3165b
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 22, 2023
0e0d73e
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 23, 2023
bea85f7
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 23, 2023
6c3a3c6
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 26, 2023
aeb658b
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 27, 2023
d5509a3
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 27, 2023
5d8c954
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 27, 2023
92d213c
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 28, 2023
24b118b
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 28, 2023
b8c4556
Update on "[inductor] convert layout of conv weight ahead of time for…
shunting314 Jun 28, 2023
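In short, this PR teaches the inductor freezing path to convert constant convolution weights to channels_last ahead of time, so that inference should not need to convert the weight layout at run time. As a minimal, standalone illustration of what that conversion means at the tensor level (not code from this PR), a 4-d weight changes its strides and becomes channels_last contiguous:

import torch

# A 4-d conv weight in the default contiguous (NCHW) layout.
weight = torch.randn(128, 3, 3, 3)
print(weight.stride())  # (27, 9, 3, 1)

# Ahead-of-time conversion to channels_last, analogous to what the
# freezing pass does for constant conv weights.
weight_cl = weight.to(memory_format=torch.channels_last)
print(weight_cl.is_contiguous(memory_format=torch.channels_last))  # True
print(weight_cl.stride())  # (27, 1, 9, 3)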
58 changes: 58 additions & 0 deletions test/inductor/test_inductor_freezing.py
@@ -10,6 +10,7 @@
import torch

import torch._dynamo
from torch import nn
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
@@ -229,6 +230,63 @@ def foo(mod, inp):
self.assertEqual(eager, compiled)
self.assertTrue(weight_ref() is None)

def test_conv_weight_layout_convert(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(
3, 128, kernel_size=3, padding=1, stride=1, bias=False
)

def forward(self, x):
return self.conv(x)

@staticmethod
def get_example_inputs():
return (torch.rand(2, 3, 5, 5).to(self.device),)

from torch._inductor.compile_fx import compile_fx, compile_fx_inner

nconv = 0

def my_inner_compile(gm, example_inputs, *args, **kwargs):
out = compile_fx_inner(gm, example_inputs, *args, **kwargs)

nonlocal nconv
convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
nconv += len(convs)
for conv in convs:
weight_node = conv.args[1]
weight_const_tensor = getattr(gm, weight_node.target)
self.assertTrue(
weight_const_tensor.is_contiguous(memory_format=torch.channels_last)
)
self.assertTrue(
weight_node.meta["val"].is_contiguous(
memory_format=torch.channels_last
)
)

return out

mod = torch.compile(
Model().eval().to(self.device),
backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
)
inp = mod.get_example_inputs()
with torch.no_grad():
mod(*inp)

if self.device == "cuda":
self.assertTrue(nconv == 1)
else:
assert self.device == "cpu"
# For CPU, we get torch.ops.mkldnn._convolution_pointwise.default
# in the joint graph rather than torch.ops.aten.convolution.default.
# Currently we only handle aten.convolution.default in layout
# optimization. That's why the count is 0 here for CPU.
self.assertTrue(nconv == 0)


if HAS_CPU and not torch.backends.mps.is_available():

8 changes: 0 additions & 8 deletions torch/_functorch/aot_autograd.py
@@ -2783,14 +2783,6 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
torch._guards.TracingContext.get().fw_metadata = fw_metadata


# The compiler needs to use this field to find the original model outputs
# from the AOTAutograd fwd module's outputs, so that the compiler can make
# sure optimizations like layout optimization do not change those tensors'
# layout.
# TODO once https://github.com/pytorch/pytorch/pull/100652/files#r1212002707 is in
# change to access fw_metadata from the global tracing context.
fw_module.meta["original_output_start_index"] = fw_metadata.num_mutated_inputs

compiled_fw_func = aot_config.fw_compiler(
fw_module, adjusted_flat_args
)
12 changes: 7 additions & 5 deletions torch/_inductor/compile_fx.py
@@ -244,6 +244,7 @@ def compile_fx_inner(
is_inference=False,
boxed_forward_device_index=None,
user_visible_outputs=frozenset(),
layout_opt=None,
):
if is_tf32_warning_applicable(gm):
_warn_tf32_disabled()
@@ -314,6 +315,7 @@
cpp_wrapper=cpp_wrapper,
aot_mode=aot_mode,
user_visible_outputs=user_visible_outputs,
layout_opt=layout_opt,
)
with V.set_graph_handler(graph):
graph.run(*example_inputs)
@@ -674,7 +676,7 @@ def fw_compiler_freezing(
graph_id,
forward_device,
):
from torch._inductor.freezing import freeze
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

# partition_fn won't be called
joint_graph_passes(aot_autograd_model)
@@ -684,6 +686,9 @@
aot_autograd_model,
fw_metadata=torch._guards.TracingContext.get().fw_metadata,
)
layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model)
if layout_opt:
convert_conv_weights_to_channels_last(aot_autograd_model)

aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
num_fixed = len(preserved_arg_indices) - num_example_inputs
@@ -700,6 +705,7 @@
graph_id=graph_id,
is_inference=True,
boxed_forward_device_index=forward_device,
layout_opt=layout_opt,
)

# Need to drop the args we have constant-ified.
@@ -821,12 +827,8 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
orig_model_outputs_node.args
)
num_orig_model_outputs = len(orig_model_outputs)
original_output_start_index = model.meta.get(
"original_output_start_index", 0
)
else:
num_orig_model_outputs = num_model_outputs
original_output_start_index = 0

assert num_orig_model_outputs <= num_model_outputs

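For context, fw_compiler_freezing now decides layout optimization once (via GraphLowering.decide_layout_opt), converts the frozen conv weights, and threads the decision into compile_fx_inner. Below is a hedged sketch of how a user would exercise this path; the freezing flag name is an assumption about the inductor config of this period (freezing is commonly enabled via torch._inductor.config.freezing or TORCHINDUCTOR_FREEZING=1), while layout_optimization is the flag referenced in this diff:

import torch
import torch._inductor.config as inductor_config

# Assumed flag names; adjust to the installed PyTorch version.
inductor_config.freezing = True             # fold conv weights into constants
inductor_config.layout_optimization = True  # allow channels_last conversion

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    torch.nn.ReLU(),
).eval()

compiled = torch.compile(model)
with torch.no_grad():
    out = compiled(torch.randn(8, 3, 32, 32))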
39 changes: 39 additions & 0 deletions torch/_inductor/freezing.py
@@ -4,8 +4,14 @@

import torch
import torch.utils._pytree as pytree
from torch import nn
from torch._dynamo.utils import dynamo_timed
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._mode_utils import no_dispatch
from . import config

aten = torch.ops.aten


def replace_node_with_constant(gm, node, constant):
g = gm.graph
@@ -224,3 +230,36 @@ def discard_traced_gm_params(mod):
e_t.requires_grad_(True)
e_t._is_param = True
setattr(mod, attr_name, e_t)


@dynamo_timed
def convert_conv_weights_to_channels_last(gm):
"""
Convert 4d convolution weight tensors to channels-last format.

This method assumes the graph has already been frozen.
"""
convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
for conv in convs:
weight_node = conv.args[1]
# is a constant tensor
if weight_node.op == "get_attr":
param_tensor = getattr(gm, weight_node.target)
if len(param_tensor.shape) != 4:
# not a 4d tensor, skip
continue
with no_dispatch():
cl_param_tensor = param_tensor.to(memory_format=torch.channels_last)
if isinstance(param_tensor, nn.Parameter):
cl_param_tensor = nn.Parameter(cl_param_tensor)
if cl_param_tensor is not param_tensor:
setattr(gm, weight_node.target, cl_param_tensor)

# Even though inductor does not use meta['val'] or meta['tensor_meta']
# for get_attr nodes, we still update them to be consistent.
weight_node.meta["val"] = weight_node.meta["val"].to(
memory_format=torch.channels_last
)
weight_node.meta["tensor_meta"] = _extract_tensor_metadata(
weight_node.meta["val"]
)
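The pass above rewrites the get_attr weight constants of an already-frozen FX graph in place. As a rough standalone analogue (applied to an eager nn.Module rather than a frozen graph, so it is not this pass itself), the same transformation looks like this:

import torch
from torch import nn

def convert_module_conv_weights_to_channels_last(mod: nn.Module) -> None:
    # Rough analogue of the FX pass: convert every 4-d Conv2d weight of an
    # eager module to channels_last and re-wrap it as an nn.Parameter.
    for submod in mod.modules():
        if isinstance(submod, nn.Conv2d) and submod.weight.dim() == 4:
            with torch.no_grad():
                cl_weight = submod.weight.to(memory_format=torch.channels_last)
            submod.weight = nn.Parameter(cl_weight, requires_grad=False)

m = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
convert_module_conv_weights_to_channels_last(m)
print(m[0].weight.is_contiguous(memory_format=torch.channels_last))  # True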
9 changes: 6 additions & 3 deletions torch/_inductor/graph.py
@@ -150,10 +150,13 @@ def __init__(
cpp_wrapper=False,
aot_mode=False,
user_visible_outputs=frozenset(),
layout_opt=None,
):
super().__init__(gm)

self.layout_opt = self.decide_layout_opt()
self.layout_opt = (
layout_opt if layout_opt is not None else self.decide_layout_opt(gm)
)
self.num_channels_last_conv = 0

self.extra_traceback = False # we do our own error wrapping
@@ -195,15 +198,15 @@ def __init__(
self._warned_fallback = {"aten.convolution_backward"}
self.user_visible_outputs = user_visible_outputs

def decide_layout_opt(self) -> bool:
@staticmethod
def decide_layout_opt(gm) -> bool:
"""
Decide if we should enable layout optimization for this graph based on
heuristics.
"""
if not config.layout_optimization:
return False

gm = self.module
conv_nodes = [
n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
]
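decide_layout_opt becomes a staticmethod so that fw_compiler_freezing can consult the heuristic before a GraphLowering object exists. Below is a hedged, standalone sketch of the kind of check it performs; make_fx is used here only to obtain an ATen-level graph (inside inductor the graph comes from AOTAutograd), and it assumes F.conv2d lowers to aten.convolution.default, which may vary across PyTorch versions:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def count_conv_nodes(gm: torch.fx.GraphModule) -> int:
    # Count ATen convolution nodes, mirroring the check in decide_layout_opt.
    return sum(
        1
        for n in gm.graph.nodes
        if n.target == torch.ops.aten.convolution.default
    )

def f(x, w):
    return torch.nn.functional.conv2d(x, w, padding=1)

gm = make_fx(f)(torch.randn(1, 3, 8, 8), torch.randn(8, 3, 3, 3))
print(count_conv_nodes(gm))  # expected: 1 on recent PyTorch versions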