[inductor] convert layout of conv weight ahead of time for inference (#103642)

This PR handles inference; a similar change for training will come later.

Some manual testing shows this can improve inference perf by 2-3% (an absolute improvement, not a relative one):
- convmixer: 4.285x -> 4.309x
- resnet50: 2.170x -> 2.203x

The PR is built upon freezing: without freezing, the weight input of a conv node may not be a parameter directly but the output of precision-converting ops, which makes this change much easier to implement after freezing.
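
For intuition, here is a minimal sketch of what converting a conv weight's layout ahead of time means. This is only an illustration, not the actual `convert_conv_weights_to_channels_last` pass added in `torch/_inductor/freezing.py`: after freezing, each conv weight is a constant `get_attr` tensor on the graph module, so it can be rewritten to channels-last once at compile time instead of being converted on every inference call. The real pass also refreshes `node.meta["val"]` afterwards via fake-tensor propagation (see `fake_tensor_prop` in the diff below).
```python
# Sketch only: assumes frozen conv weights are flat get_attr constants
# on the GraphModule, as the new test below checks.
import torch


def convert_conv_weights_to_channels_last_sketch(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        if node.target != torch.ops.aten.convolution.default:
            continue
        weight_node = node.args[1]
        if weight_node.op != "get_attr":
            continue  # weight is produced by another op; skip it
        weight = getattr(gm, weight_node.target)
        if weight.dim() == 4 and not weight.is_contiguous(
            memory_format=torch.channels_last
        ):
            # convert once, at compile time, so the lowered kernel sees
            # a channels-last constant
            setattr(
                gm,
                weight_node.target,
                weight.contiguous(memory_format=torch.channels_last),
            )
```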

Commands
```
TORCHINDUCTOR_FREEZING=1 python benchmarks/dynamo/timm_models.py --backend inductor --amp --performance --only convmixer_768_32 --inference
```
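
For reference, freezing can also be enabled from Python through the inductor config, which is the same flag the new `test_keep_output_layout_with_freezing` test patches; a minimal sketch:
```python
import torch
from torch import nn
from torch._inductor import config

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()).eval()
x = torch.rand(2, 3, 16, 16)

# config.freezing is the flag behind TORCHINDUCTOR_FREEZING=1
with config.patch({"freezing": True}), torch.no_grad():
    compiled = torch.compile(model)
    out = compiled(x)
```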

Pull Request resolved: #103642
Approved by: https://github.com/eellison
shunting314 authored and pytorchmergebot committed Jun 28, 2023
1 parent 044a8e3 commit 98f00f8
Showing 9 changed files with 292 additions and 26 deletions.
128 changes: 127 additions & 1 deletion test/inductor/test_inductor_freezing.py
@@ -2,6 +2,7 @@
import contextlib
import functools
import importlib
import itertools
import os
import sys
import unittest
@@ -10,8 +11,9 @@
import torch

import torch._dynamo
from torch import nn
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch._inductor.utils import override_lowering, run_and_get_code
from torch.testing import FileCheck

# Make the helper files in test/ importable
@@ -43,6 +45,7 @@

HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
aten = torch.ops.aten
prims = torch.ops.prims
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


@@ -312,6 +315,129 @@ def foo(mod, inp):
self.assertEqual(eager, compiled)
self.assertTrue(weight_ref() is None)

def test_conv_weight_layout_convert(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(
3, 128, kernel_size=3, padding=1, stride=1, bias=False
)

def forward(self, x):
return self.conv(x)

@staticmethod
def get_example_inputs():
return (torch.rand(2, 3, 5, 5).to(self.device),)

from torch._inductor.compile_fx import compile_fx, compile_fx_inner

nconv = 0

def my_inner_compile(gm, example_inputs, *args, **kwargs):
out = compile_fx_inner(gm, example_inputs, *args, **kwargs)

nonlocal nconv
convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
nconv += len(convs)
for conv in convs:
weight_node = conv.args[1]
weight_const_tensor = getattr(gm, weight_node.target)
self.assertTrue(
weight_const_tensor.is_contiguous(memory_format=torch.channels_last)
)
self.assertTrue(
weight_node.meta["val"].is_contiguous(
memory_format=torch.channels_last
)
)

return out

mod = torch.compile(
Model().eval().to(self.device),
backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
)
inp = mod.get_example_inputs()
with torch.no_grad():
mod(*inp)

# Only check the assertion for CUDA.
# For CPU, we may get torch.ops.mkldnn._convolution_pointwise.default
# in the joint graph rather than torch.ops.aten.convolution.default.
# Currently we only handle aten.convolution.default in layout
# optimization. That's why the count may be 0 here for CPU.
if self.device == "cuda":
self.assertTrue(nconv == 1)

def test_redundant_clone_for_layout_convert(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(
3, 128, kernel_size=3, padding=1, stride=1, bias=False
)

def forward(self, x):
y = x + 1
return self.conv(x), y

@staticmethod
def get_example_inputs():
return (torch.rand(2, 3, 5, 5).to(self.device),)

mod = Model().eval().to(self.device)
inp = mod.get_example_inputs()
with torch.no_grad():
expected_outputs = mod(*inp)

num_same_stride = 0
num_diff_stride = 0

def debug_inductor_force_stride_order(orig_fn, input_tensor, stride):
nonlocal num_same_stride, num_diff_stride
input_tensor.realize()
if tuple(input_tensor.get_stride()) == tuple(stride):
num_same_stride += 1
else:
num_diff_stride += 1
return orig_fn(input_tensor, stride)

with override_lowering(
prims.inductor_force_stride_order.default, debug_inductor_force_stride_order
):
opt_mod = torch.compile(mod)
with torch.no_grad():
actual_outputs = opt_mod(*inp)

self.assertEqual(len(actual_outputs), len(expected_outputs))
self.assertEqual(2, len(actual_outputs))
for i, actual, expected in zip(
itertools.count(), actual_outputs, expected_outputs
):
self.assertTrue(
torch.allclose(expected, actual, atol=1e-4, rtol=1e-4),
f"{i}th output: expected {expected}, actual {actual}",
)

if self.device == "cpu":
# CPU uses a different convolution implementation; skip the checks below
return

self.assertTrue(
actual_outputs[0].is_contiguous(memory_format=torch.contiguous_format)
)
self.assertTrue(
actual_outputs[1].is_contiguous(memory_format=torch.contiguous_format)
)

# we don't change the stride of y returned by forward. So there will
# be no extra copy
self.assertTrue(num_same_stride == 1, f"num_same_stride is {num_same_stride}")
# we changed the stride of self.conv(x) returned by forward. So there
# may be an extra copy
self.assertTrue(num_diff_stride == 1, f"num_diff_stride is {num_diff_stride}")


if HAS_CPU and not torch.backends.mps.is_available():

10 changes: 10 additions & 0 deletions test/inductor/test_layout_optim.py
@@ -7,6 +7,7 @@
from torch import nn
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import same
from torch._inductor import config
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import HAS_CUDA

@@ -149,6 +150,7 @@ def get_example_inputs(self):

self.verify_accuracy_for_infer(Model)

@torch.no_grad()
def test_keep_output_layout_infer(self):
class Model(nn.Module):
def __init__(self):
@@ -178,6 +180,14 @@ def get_example_inputs(self):
# Note that if the output is channels last, the view op will fail.
opt_out.view(5, -1)

def test_keep_output_layout_with_freezing(self):
with config.patch(
{
"freezing": True,
}
):
self.test_keep_output_layout_infer()

def test_training_acc(self):
self.verify_accuracy_for_train(Model2Conv)

8 changes: 0 additions & 8 deletions torch/_functorch/aot_autograd.py
@@ -2767,14 +2767,6 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
torch._guards.TracingContext.get().fw_metadata = fw_metadata


# the compiler needs to use this field to find the original model outputs
# from the AOTAutograd fwd module's outputs. Thus the compiler can make sure
# optimizations like layout optimization do not change those tensors'
# layout.
# TODO once https://github.com/pytorch/pytorch/pull/100652/files#r1212002707 is in
# change to access fw_metadata from the global tracing context.
fw_module.meta["original_output_start_index"] = fw_metadata.num_mutated_inputs

compiled_fw_func = aot_config.fw_compiler(
fw_module, adjusted_flat_args
)
61 changes: 48 additions & 13 deletions torch/_inductor/compile_fx.py
@@ -1,3 +1,4 @@
import contextlib
import dataclasses
import functools
import itertools
@@ -232,6 +233,34 @@ def materialize(x):
return wrapper


def fake_tensor_prop(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
force_allow_non_fake_inputs=False,
):
"""
If we can not detect fake mode from the context of inputs, create one.
The created fake mode will be returned.
"""
fake_mode = detect_fake_mode(example_inputs)
if not fake_mode:
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
else:
ctx = (
contextlib.nullcontext()
if not force_allow_non_fake_inputs
else unittest.mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
)
with ctx:
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
*example_inputs
)

return fake_mode


@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
@time_and_log(attr="compilation time (in seconds)")
@@ -247,6 +276,7 @@ def compile_fx_inner(
is_inference=False,
boxed_forward_device_index=None,
user_visible_outputs=frozenset(),
layout_opt=None,
):
if dynamo_utils.count_calls(gm.graph) == 0:
return make_boxed_func(gm.forward)
@@ -265,6 +295,7 @@
"aot_mode": aot_mode,
"is_inference": is_inference,
"user_visible_outputs": user_visible_outputs,
"layout_opt": layout_opt,
}

compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
@@ -403,6 +434,7 @@ def fx_codegen_and_compile(
aot_mode=False,
is_inference=False,
user_visible_outputs=frozenset(),
layout_opt=None,
):
if is_tf32_warning_applicable(gm):
_warn_tf32_disabled()
@@ -439,14 +471,8 @@
# .view() call.
view_to_reshape(gm)

fake_mode = detect_fake_mode(example_inputs)
if not fake_mode:
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
else:
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
*example_inputs
)
fake_mode = fake_tensor_prop(gm, example_inputs)

# pattern matcher passes might not preserve striding information
# on node.meta["val"]. if in the future we rely on these being
# correct we will need to fix.
@@ -746,11 +772,17 @@ def fw_compiler_freezing(
graph_id,
forward_device,
):
from torch._inductor.freezing import freeze
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

# partition_fn won't be called
joint_graph_passes(aot_autograd_model)

layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model)
if layout_opt:
# make sure meta['val'] is properly set up
fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
convert_conv_weights_to_channels_last(aot_autograd_model)

opt_model, preserved_arg_indices = freeze(
dynamo_model,
aot_autograd_model,
@@ -762,6 +794,11 @@

fake_mode = detect_fake_mode(aot_example_inputs)

# for freezing, all graph outputs should be user visible
*_, model_outputs_node = opt_model.graph.nodes
model_outputs = model_outputs_node.args[0]
user_visible_outputs = [n.name for n in model_outputs]

# constant params will be real tensors, not fake
with unittest.mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
optimized_function = inner_compile(
@@ -772,6 +809,8 @@
graph_id=graph_id,
is_inference=True,
boxed_forward_device_index=forward_device,
layout_opt=layout_opt,
user_visible_outputs=user_visible_outputs,
)

# Need to drop the args we have constant-ified.
@@ -898,12 +937,8 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
orig_model_outputs_node.args
)
num_orig_model_outputs = len(orig_model_outputs)
original_output_start_index = model.meta.get(
"original_output_start_index", 0
)
else:
num_orig_model_outputs = num_model_outputs
original_output_start_index = 0

assert num_orig_model_outputs <= num_model_outputs

