
[inductor] convert layout of conv weight ahead of time for inference #103642

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
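At a high level, the PR converts the weights of convolution layers to the channels_last memory format once, as part of freezing a model for inference, instead of leaving the layout conversion to happen at run time. A rough standalone sketch of the idea in eager mode (illustrative only, not the graph pass this PR adds to inductor):

import torch
from torch import nn

# A toy model standing in for whatever is being frozen for inference.
model = nn.Sequential(
    nn.Conv2d(3, 128, kernel_size=3, padding=1, stride=1, bias=False)
).eval()

with torch.no_grad():
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # Rewrite the 4D conv weight to channels_last once, ahead of time,
            # so kernels that prefer that layout need no per-call conversion.
            module.weight.data = module.weight.data.to(
                memory_format=torch.channels_last
            )

    x = torch.rand(2, 3, 5, 5).to(memory_format=torch.channels_last)
    out = model(x)

The tests added below check this property end to end: after freezing, the conv weight constants in the compiled graph are channels_last, and no redundant copies are inserted for outputs whose strides the layout optimization did not change.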
Commits:
3b03e06  [wip][inductor] convert layout of conv weight ahead of time (shunting314, Jun 15, 2023)
ca24229  Update on "[wip][inductor] convert layout of conv weight ahead of tim… (shunting314, Jun 15, 2023)
379fabe  Update on "[wip][inductor] convert layout of conv weight ahead of tim… (shunting314, Jun 15, 2023)
53e9ed9  Update on "[wip][inductor] convert layout of conv weight ahead of tim… (shunting314, Jun 16, 2023)
27b08b4  Update on "[wip][inductor] convert layout of conv weight ahead of tim… (shunting314, Jun 16, 2023)
7c200c3  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 16, 2023)
0122768  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 16, 2023)
230405e  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 22, 2023)
dd3165b  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 22, 2023)
0e0d73e  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 23, 2023)
bea85f7  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 23, 2023)
6c3a3c6  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 26, 2023)
aeb658b  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 27, 2023)
d5509a3  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 27, 2023)
5d8c954  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 27, 2023)
92d213c  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 28, 2023)
24b118b  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 28, 2023)
b8c4556  Update on "[inductor] convert layout of conv weight ahead of time for… (shunting314, Jun 28, 2023)
128 changes: 127 additions & 1 deletion test/inductor/test_inductor_freezing.py
@@ -2,6 +2,7 @@
import contextlib
import functools
import importlib
import itertools
import os
import sys
import unittest
@@ -10,8 +11,9 @@
import torch

import torch._dynamo
from torch import nn
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch._inductor.utils import override_lowering, run_and_get_code
from torch.testing import FileCheck

# Make the helper files in test/ importable
@@ -43,6 +45,7 @@

HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
aten = torch.ops.aten
prims = torch.ops.prims
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


@@ -312,6 +315,129 @@ def foo(mod, inp):
self.assertEqual(eager, compiled)
self.assertTrue(weight_ref() is None)

def test_conv_weight_layout_convert(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(
3, 128, kernel_size=3, padding=1, stride=1, bias=False
)

def forward(self, x):
return self.conv(x)

@staticmethod
def get_example_inputs():
return (torch.rand(2, 3, 5, 5).to(self.device),)

from torch._inductor.compile_fx import compile_fx, compile_fx_inner

nconv = 0

def my_inner_compile(gm, example_inputs, *args, **kwargs):
out = compile_fx_inner(gm, example_inputs, *args, **kwargs)

nonlocal nconv
convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
nconv += len(convs)
for conv in convs:
weight_node = conv.args[1]
weight_const_tensor = getattr(gm, weight_node.target)
self.assertTrue(
weight_const_tensor.is_contiguous(memory_format=torch.channels_last)
)
self.assertTrue(
weight_node.meta["val"].is_contiguous(
memory_format=torch.channels_last
)
)

return out

mod = torch.compile(
Model().eval().to(self.device),
backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
)
inp = mod.get_example_inputs()
with torch.no_grad():
mod(*inp)

# Only check the assertion for CUDA.
# For CPU, we may get torch.ops.mkldnn._convolution_pointwise.default
# in the joint graph rather than torch.ops.aten.convolution.default.
# Currently we only handle aten.convolution.default in layout
# optimization. That's why the count may be 0 here for CPU.
if self.device == "cuda":
self.assertTrue(nconv == 1)

def test_redundant_clone_for_layout_convert(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(
3, 128, kernel_size=3, padding=1, stride=1, bias=False
)

def forward(self, x):
y = x + 1
return self.conv(x), y

@staticmethod
def get_example_inputs():
return (torch.rand(2, 3, 5, 5).to(self.device),)

mod = Model().eval().to(self.device)
inp = mod.get_example_inputs()
with torch.no_grad():
expected_outputs = mod(*inp)

num_same_stride = 0
num_diff_stride = 0

def debug_inductor_force_stride_order(orig_fn, input_tensor, stride):
nonlocal num_same_stride, num_diff_stride
input_tensor.realize()
if tuple(input_tensor.get_stride()) == tuple(stride):
num_same_stride += 1
else:
num_diff_stride += 1
return orig_fn(input_tensor, stride)

with override_lowering(
prims.inductor_force_stride_order.default, debug_inductor_force_stride_order
):
opt_mod = torch.compile(mod)
with torch.no_grad():
actual_outputs = opt_mod(*inp)

self.assertEqual(len(actual_outputs), len(expected_outputs))
self.assertEqual(2, len(actual_outputs))
for i, actual, expected in zip(
itertools.count(), actual_outputs, expected_outputs
):
self.assertTrue(
torch.allclose(expected, actual, atol=1e-4, rtol=1e-4),
f"{i}th output: expected {expected}, actual {actual}",
)

if self.device == "cpu":
# CPU uses a different convolution implementation; skip the checks below
return

self.assertTrue(
actual_outputs[0].is_contiguous(memory_format=torch.contiguous_format)
)
self.assertTrue(
actual_outputs[1].is_contiguous(memory_format=torch.contiguous_format)
)

# we don't change the stride of y returned by forward. So there will
# be no extra copy
self.assertTrue(num_same_stride == 1, f"num_same_stride is {num_same_stride}")
# we changed the stride of self.conv(x) returned by forward. So there
# may be an extra copy
self.assertTrue(num_diff_stride == 1, f"num_diff_stride is {num_diff_stride}")


if HAS_CPU and not torch.backends.mps.is_available():

10 changes: 10 additions & 0 deletions test/inductor/test_layout_optim.py
@@ -7,6 +7,7 @@
from torch import nn
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import same
from torch._inductor import config
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import HAS_CUDA

@@ -149,6 +150,7 @@ def get_example_inputs(self):

self.verify_accuracy_for_infer(Model)

@torch.no_grad()
def test_keep_output_layout_infer(self):
class Model(nn.Module):
def __init__(self):
@@ -178,6 +180,14 @@ def get_example_inputs(self):
# Note that if the output is channels last, the view op will fail.
opt_out.view(5, -1)

def test_keep_output_layout_with_freezing(self):
with config.patch(
{
"freezing": True,
}
):
self.test_keep_output_layout_infer()

def test_training_acc(self):
self.verify_accuracy_for_train(Model2Conv)

8 changes: 0 additions & 8 deletions torch/_functorch/aot_autograd.py
@@ -2767,14 +2767,6 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
torch._guards.TracingContext.get().fw_metadata = fw_metadata


# the compiler need to use this field to find the original modol outputs
# from the AOTAutograd fwd module's outputs. Thus compiler can make sure
# optimizations like layout optimization does not change those tensors'
# layout.
# TODO once https://github.com/pytorch/pytorch/pull/100652/files#r1212002707 is in
# change to access fw_metadata from the global tracing context.
fw_module.meta["original_output_start_index"] = fw_metadata.num_mutated_inputs

compiled_fw_func = aot_config.fw_compiler(
fw_module, adjusted_flat_args
)
61 changes: 48 additions & 13 deletions torch/_inductor/compile_fx.py
@@ -1,3 +1,4 @@
import contextlib
import dataclasses
import functools
import itertools
@@ -232,6 +233,34 @@ def materialize(x):
return wrapper


def fake_tensor_prop(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
force_allow_non_fake_inputs=False,
):
"""
If we cannot detect a fake mode from the inputs' context, create one.

The fake mode that was used for propagation is returned.
"""
fake_mode = detect_fake_mode(example_inputs)
if not fake_mode:
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
else:
ctx = (
contextlib.nullcontext()
if not force_allow_non_fake_inputs
else unittest.mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
)
with ctx:
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
*example_inputs
)

return fake_mode


@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
@time_and_log(attr="compilation time (in seconds)")
@@ -247,6 +276,7 @@ def compile_fx_inner(
is_inference=False,
boxed_forward_device_index=None,
user_visible_outputs=frozenset(),
layout_opt=None,
):
if dynamo_utils.count_calls(gm.graph) == 0:
return make_boxed_func(gm.forward)
@@ -265,6 +295,7 @@
"aot_mode": aot_mode,
"is_inference": is_inference,
"user_visible_outputs": user_visible_outputs,
"layout_opt": layout_opt,
}

compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
@@ -403,6 +434,7 @@ def fx_codegen_and_compile(
aot_mode=False,
is_inference=False,
user_visible_outputs=frozenset(),
layout_opt=None,
):
if is_tf32_warning_applicable(gm):
_warn_tf32_disabled()
@@ -439,14 +471,8 @@
# .view() call.
view_to_reshape(gm)

fake_mode = detect_fake_mode(example_inputs)
if not fake_mode:
fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
else:
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
*example_inputs
)
fake_mode = fake_tensor_prop(gm, example_inputs)

# pattern matcher passes might not preserve striding information
# on node.meta["val"]. if in the future we rely on these being
# correct we will need to fix.
@@ -746,11 +772,17 @@ def fw_compiler_freezing(
graph_id,
forward_device,
):
from torch._inductor.freezing import freeze
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

# partition_fn won't be called
joint_graph_passes(aot_autograd_model)

layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model)
if layout_opt:
# make sure meta['val'] is properly setup
fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
convert_conv_weights_to_channels_last(aot_autograd_model)
Review thread on this change:

Contributor: I would have thought we needed the fake_tensor_prop after, but not before. I do think we need it after, or the metas will be incorrect.

shunting314 (author): Run a FakeTensorProp before the conversion to make sure meta['val'] is properly set up. I'm worried that it may not already be set up. But if we are sure FakeTensorProp has already been called before, we can save this call.

Contributor: I think we should add it again, since now the existing metas will be incorrect because we've changed striding.
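The concern in this thread is that node.meta["val"] entries recorded by an earlier FakeTensorProp run still describe the weights' old contiguous strides once the weights have been rewritten to channels_last. A rough standalone illustration of the refresh step being discussed (a hypothetical toy module, not this PR's helper):

import torch
from torch import nn
from torch._subclasses import FakeTensorMode
from torch.fx import symbolic_trace
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

class ToyConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.conv(x)

gm = symbolic_trace(ToyConv().eval())
example_inputs = (torch.rand(2, 3, 5, 5),)

# First propagation records meta["val"] against the original contiguous weight.
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)

# Ahead-of-time layout change: rewrite the conv weight to channels_last.
gm.conv.weight.data = gm.conv.weight.data.to(memory_format=torch.channels_last)

# The meta["val"] recorded above does not know about the new strides; running
# the propagation again recomputes it against the updated weight layout.
FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)

Whether the propagation runs before the conversion, after it, or both is exactly what this thread settles.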
opt_model, preserved_arg_indices = freeze(
dynamo_model,
aot_autograd_model,
@@ -762,6 +794,11 @@

fake_mode = detect_fake_mode(aot_example_inputs)

# for freezing, all graph outputs should be user visible
*_, model_outputs_node = opt_model.graph.nodes
model_outputs = model_outputs_node.args[0]
user_visible_outputs = [n.name for n in model_outputs]

# constant params will be real tensors, not fake
with unittest.mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
optimized_function = inner_compile(
Expand All @@ -772,6 +809,8 @@ def fw_compiler_freezing(
graph_id=graph_id,
is_inference=True,
boxed_forward_device_index=forward_device,
layout_opt=layout_opt,
user_visible_outputs=user_visible_outputs,
)

# Need to drop the args we have constant-ified.
@@ -898,12 +937,8 @@ def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
orig_model_outputs_node.args
)
num_orig_model_outputs = len(orig_model_outputs)
original_output_start_index = model.meta.get(
"original_output_start_index", 0
)
else:
num_orig_model_outputs = num_model_outputs
original_output_start_index = 0

assert num_orig_model_outputs <= num_model_outputs

Expand Down